mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Refactor JIT for unary/binary/ternary ops (#1206)
* refactor unary/binary/ternary ops * get_primitive_string util ---------
This commit is contained in:
parent
de2b9e7d0a
commit
934683088e
@ -6,6 +6,17 @@
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#define BINARY_GPU(func) \
|
||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this)); \
|
||||
}
|
||||
|
||||
#define BINARY_GPU_MULTI(func) \
|
||||
void func::eval_gpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
binary_op_gpu(inputs, outputs, get_primitive_string(this)); \
|
||||
}
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
|
||||
@ -61,7 +72,8 @@ void binary_op_gpu_inplace(
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto kernel = get_binary_two_kernel(d, kernel_name, a, outputs[0]);
|
||||
auto kernel =
|
||||
get_binary_two_kernel(d, kernel_name, a.dtype(), outputs[0].dtype(), op);
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
@ -188,7 +200,7 @@ void binary_op_gpu_inplace(
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto kernel = get_binary_kernel(d, kernel_name, a, out);
|
||||
auto kernel = get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
bool donate_a = a.data_shared_ptr() == nullptr;
|
||||
@ -259,102 +271,44 @@ void binary_op_gpu(
|
||||
binary_op_gpu(inputs, out, op, s);
|
||||
}
|
||||
|
||||
void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "add");
|
||||
}
|
||||
|
||||
void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "arctan2");
|
||||
}
|
||||
BINARY_GPU(Add)
|
||||
BINARY_GPU(ArcTan2)
|
||||
BINARY_GPU(Divide)
|
||||
BINARY_GPU_MULTI(DivMod)
|
||||
BINARY_GPU(Remainder)
|
||||
BINARY_GPU(Equal)
|
||||
BINARY_GPU(Greater)
|
||||
BINARY_GPU(GreaterEqual)
|
||||
BINARY_GPU(Less)
|
||||
BINARY_GPU(LessEqual)
|
||||
BINARY_GPU(LogicalAnd)
|
||||
BINARY_GPU(LogicalOr)
|
||||
BINARY_GPU(LogAddExp)
|
||||
BINARY_GPU(Maximum)
|
||||
BINARY_GPU(Minimum)
|
||||
BINARY_GPU(Multiply)
|
||||
BINARY_GPU(NotEqual)
|
||||
BINARY_GPU(Power)
|
||||
BINARY_GPU(Subtract)
|
||||
|
||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
binary_op_gpu(inputs, out, "bitwise_and");
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
binary_op_gpu(inputs, out, "bitwise_or");
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
binary_op_gpu(inputs, out, "bitwise_xor");
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
binary_op_gpu(inputs, out, "left_shift");
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
binary_op_gpu(inputs, out, "right_shift");
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "div");
|
||||
}
|
||||
|
||||
void DivMod::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
binary_op_gpu(inputs, outputs, "divmod");
|
||||
}
|
||||
|
||||
void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "rem");
|
||||
}
|
||||
|
||||
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, equal_nan_ ? "naneq" : "eq");
|
||||
}
|
||||
|
||||
void Greater::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "ge");
|
||||
}
|
||||
|
||||
void GreaterEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "geq");
|
||||
}
|
||||
|
||||
void Less::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "le");
|
||||
}
|
||||
|
||||
void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "leq");
|
||||
}
|
||||
|
||||
void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "land");
|
||||
}
|
||||
|
||||
void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "lor");
|
||||
}
|
||||
|
||||
void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "lae");
|
||||
}
|
||||
|
||||
void Maximum::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "max");
|
||||
}
|
||||
|
||||
void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "min");
|
||||
}
|
||||
|
||||
void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "mul");
|
||||
}
|
||||
|
||||
void NotEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "neq");
|
||||
}
|
||||
|
||||
void Power::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "pow");
|
||||
}
|
||||
|
||||
void Subtract::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "sub");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -391,16 +391,16 @@ void multi_upload_bluestein_fft(
|
||||
std::vector<int> rstrides(in.ndim(), 1);
|
||||
rstarts[axis] = in.shape(axis) - back_offset;
|
||||
rstrides[axis] = -1;
|
||||
unary_op_gpu({in}, conj_temp, "conj", s);
|
||||
unary_op_gpu({in}, conj_temp, "Conjugate", s);
|
||||
slice_gpu(in, slice_temp, rstarts, rstrides, s);
|
||||
concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);
|
||||
} else if (inverse) {
|
||||
unary_op_gpu({in}, temp, "conj", s);
|
||||
unary_op_gpu({in}, temp, "Conjugate", s);
|
||||
} else {
|
||||
temp.copy_shared_buffer(in);
|
||||
}
|
||||
|
||||
binary_op_gpu({temp, w_k_broadcast}, temp1, "mul", s);
|
||||
binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s);
|
||||
|
||||
std::vector<std::pair<int, int>> pads;
|
||||
auto padded_shape = out.shape();
|
||||
@ -419,7 +419,7 @@ void multi_upload_bluestein_fft(
|
||||
/*inplace=*/false,
|
||||
s);
|
||||
|
||||
binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "mul", s);
|
||||
binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s);
|
||||
|
||||
fft_op(
|
||||
pad_temp,
|
||||
@ -437,7 +437,7 @@ void multi_upload_bluestein_fft(
|
||||
starts[axis] = plan.bluestein_n - offset - n;
|
||||
slice_gpu(pad_temp1, temp, starts, strides, s);
|
||||
|
||||
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "mul", s);
|
||||
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s);
|
||||
|
||||
if (real && !inverse) {
|
||||
std::vector<int> rstarts(in.ndim(), 0);
|
||||
@ -451,11 +451,11 @@ void multi_upload_bluestein_fft(
|
||||
copies.push_back(inv_n);
|
||||
|
||||
copy_gpu(temp1, temp_float, CopyType::General, s);
|
||||
binary_op_gpu({temp_float, inv_n}, out, "mul", s);
|
||||
binary_op_gpu({temp_float, inv_n}, out, "Multiply", s);
|
||||
} else if (inverse) {
|
||||
auto inv_n = array({1.0f / n}, {1}, complex64);
|
||||
unary_op_gpu({temp1}, temp, "conj", s);
|
||||
binary_op_gpu({temp, inv_n}, out, "mul", s);
|
||||
unary_op_gpu({temp1}, temp, "Conjugate", s);
|
||||
binary_op_gpu({temp, inv_n}, out, "Multiply", s);
|
||||
copies.push_back(inv_n);
|
||||
} else {
|
||||
out.copy_shared_buffer(temp1);
|
||||
|
@ -1,87 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view binary_kernels = R"(
|
||||
template [[host_name("ss{0}")]] [[kernel]]
|
||||
void binary_ss<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vs{0}")]] [[kernel]]
|
||||
void binary_vs<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("sv{0}")]] [[kernel]]
|
||||
void binary_sv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vv{0}")]] [[kernel]]
|
||||
void binary_vv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g4{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 4>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const int shape[4],
|
||||
constant const size_t a_strides[4],
|
||||
constant const size_t b_strides[4],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g5{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 5>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const int shape[5],
|
||||
constant const size_t a_strides[5],
|
||||
constant const size_t b_strides[5],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("g1{0}")]] [[kernel]] void
|
||||
binary_g_nd1<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g2{0}")]] [[kernel]] void
|
||||
binary_g_nd2<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g3{0}")]] [[kernel]] void
|
||||
binary_g_nd3<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("gn{0}")]] [[kernel]]
|
||||
void binary_g<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
)";
|
@ -1,98 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view binary_two_kernels = R"(
|
||||
template [[host_name("ss{0}")]] [[kernel]]
|
||||
void binary_ss<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vs{0}")]] [[kernel]]
|
||||
void binary_vs<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("sv{0}")]] [[kernel]]
|
||||
void binary_sv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vv{0}")]] [[kernel]]
|
||||
void binary_vv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
template [[host_name("g4{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 4>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const int shape[4],
|
||||
constant const size_t a_strides[4],
|
||||
constant const size_t b_strides[4],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g5{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 5>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const int shape[5],
|
||||
constant const size_t a_strides[5],
|
||||
constant const size_t b_strides[5],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("g1{0}")]] [[kernel]] void
|
||||
binary_g_nd1<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g2{0}")]] [[kernel]] void
|
||||
binary_g_nd2<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g3{0}")]] [[kernel]] void
|
||||
binary_g_nd3<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("gn{0}")]] [[kernel]]
|
||||
void binary_g<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
)";
|
@ -1,80 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view ternary_kernels = R"(
|
||||
template [[host_name("v_{0}")]] [[kernel]] void ternary_v<{1}, {2}>(
|
||||
device const bool* a,
|
||||
device const {1}* b,
|
||||
device const {1}* c,
|
||||
device {1}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
template [[host_name("g_{0}")]] [[kernel]] void ternary_g<{1}, {2}>(
|
||||
device const bool* a,
|
||||
device const {1}* b,
|
||||
device const {1}* c,
|
||||
device {1}* d,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const size_t* c_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("g1_{0}")]] [[kernel]] void
|
||||
ternary_g_nd1<{1}, {2}>(
|
||||
device const bool* a,
|
||||
device const {1}* b,
|
||||
device const {1}* c,
|
||||
device {1}* d,
|
||||
constant const size_t& a_strides,
|
||||
constant const size_t& b_strides,
|
||||
constant const size_t& c_strides,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g2_{0}")]] [[kernel]] void
|
||||
ternary_g_nd2<{1}, {2}>(
|
||||
device const bool* a,
|
||||
device const {1}* b,
|
||||
device const {1}* c,
|
||||
device {1}* d,
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
constant const size_t c_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g3_{0}")]] [[kernel]] void
|
||||
ternary_g_nd3<{1}, {2}>(
|
||||
device const bool* a,
|
||||
device const {1}* b,
|
||||
device const {1}* c,
|
||||
device {1}* d,
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
constant const size_t c_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g4_{0}")]] [[kernel]] void
|
||||
ternary_g_nd<{1}, {2}, 4>(
|
||||
device const bool* a,
|
||||
device const {1}* b,
|
||||
device const {1}* c,
|
||||
device {1}* d,
|
||||
constant const int shape[4],
|
||||
constant const size_t a_strides[4],
|
||||
constant const size_t b_strides[4],
|
||||
constant const size_t c_strides[4],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g5_{0}")]] [[kernel]] void
|
||||
ternary_g_nd<{1}, {2}, 5>(
|
||||
device const bool* a,
|
||||
device const {1}* b,
|
||||
device const {1}* c,
|
||||
device {1}* d,
|
||||
constant const int shape[5],
|
||||
constant const size_t a_strides[5],
|
||||
constant const size_t b_strides[5],
|
||||
constant const size_t c_strides[5],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
)";
|
@ -1,16 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view unary_kernels = R"(
|
||||
template [[host_name("v{0}")]] [[kernel]] void unary_v<{1}, {2}>(
|
||||
device const {1}* in,
|
||||
device {1}* out,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
template [[host_name("g{0}")]] [[kernel]] void unary_g<{1}, {2}>(
|
||||
device const {1}* in,
|
||||
device {1}* out,
|
||||
device const int* in_shape,
|
||||
device const size_t* in_strides,
|
||||
device const int& ndim,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
)";
|
@ -1,9 +1,8 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <map>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/jit/arange.h"
|
||||
#include "mlx/backend/metal/jit/binary.h"
|
||||
#include "mlx/backend/metal/jit/binary_two.h"
|
||||
#include "mlx/backend/metal/jit/copy.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/jit/reduce.h"
|
||||
@ -12,8 +11,6 @@
|
||||
#include "mlx/backend/metal/jit/sort.h"
|
||||
#include "mlx/backend/metal/jit/steel_conv.h"
|
||||
#include "mlx/backend/metal/jit/steel_gemm.h"
|
||||
#include "mlx/backend/metal/jit/ternary.h"
|
||||
#include "mlx/backend/metal/jit/unary.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
|
||||
@ -46,38 +43,76 @@ MTL::ComputePipelineState* get_arange_kernel(
|
||||
MTL::ComputePipelineState* get_unary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out) {
|
||||
Dtype out_type,
|
||||
const std::string op) {
|
||||
std::string lib_name = kernel_name.substr(1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
auto u_def = get_template_definition(
|
||||
"v" + lib_name, "unary_v", get_type_string(out_type), op);
|
||||
auto g_def = get_template_definition(
|
||||
"g" + lib_name, "unary_g", get_type_string(out_type), op);
|
||||
kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
|
||||
<< fmt::format(
|
||||
unary_kernels,
|
||||
lib_name,
|
||||
get_type_string(out.dtype()),
|
||||
op_name(out));
|
||||
<< u_def << g_def;
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
void add_binary_kernels(
|
||||
const std::string lib_name,
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op,
|
||||
std::ostringstream& kernel_source) {
|
||||
const std::map<std::string, std::string> kernel_types = {
|
||||
{"ss", "binary_ss"},
|
||||
{"vs", "binary_vs"},
|
||||
{"sv", "binary_sv"},
|
||||
{"vv", "binary_vv"},
|
||||
{"g1", "binary_g_nd1"},
|
||||
{"g2", "binary_g_nd2"},
|
||||
{"g3", "binary_g_nd3"},
|
||||
{"g4", "binary_g_nd"},
|
||||
{"g5", "binary_g_nd"},
|
||||
{"gn", "binary_g"},
|
||||
};
|
||||
for (auto [name, func] : kernel_types) {
|
||||
std::string template_def;
|
||||
if (name == "g4" || name == "g5") {
|
||||
int dim = std::stoi(name.substr(1));
|
||||
template_def = get_template_definition(
|
||||
name + lib_name,
|
||||
func,
|
||||
get_type_string(in_type),
|
||||
get_type_string(out_type),
|
||||
op,
|
||||
dim);
|
||||
} else {
|
||||
template_def = get_template_definition(
|
||||
name + lib_name,
|
||||
func,
|
||||
get_type_string(in_type),
|
||||
get_type_string(out_type),
|
||||
op);
|
||||
}
|
||||
kernel_source << template_def;
|
||||
}
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_binary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& out) {
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op) {
|
||||
std::string lib_name = kernel_name.substr(2);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::binary_ops() << metal::binary()
|
||||
<< fmt::format(
|
||||
binary_kernels,
|
||||
lib_name,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(out.dtype()),
|
||||
op_name(out));
|
||||
kernel_source << metal::utils() << metal::binary_ops() << metal::binary();
|
||||
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
@ -86,20 +121,16 @@ MTL::ComputePipelineState* get_binary_kernel(
|
||||
MTL::ComputePipelineState* get_binary_two_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& out) {
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op) {
|
||||
std::string lib_name = kernel_name.substr(2);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::binary_ops()
|
||||
<< metal::binary_two()
|
||||
<< fmt::format(
|
||||
binary_two_kernels,
|
||||
lib_name,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(out.dtype()),
|
||||
op_name(out));
|
||||
<< metal::binary_two();
|
||||
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
@ -108,17 +139,34 @@ MTL::ComputePipelineState* get_binary_two_kernel(
|
||||
MTL::ComputePipelineState* get_ternary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out) {
|
||||
Dtype type,
|
||||
const std::string op) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary()
|
||||
<< fmt::format(
|
||||
ternary_kernels,
|
||||
lib_name,
|
||||
get_type_string(out.dtype()),
|
||||
op_name(out));
|
||||
const std::map<std::string, std::string> kernel_types = {
|
||||
{"v", "ternary_v"},
|
||||
{"g", "ternary_g"},
|
||||
{"g1", "ternary_g_nd1"},
|
||||
{"g2", "ternary_g_nd2"},
|
||||
{"g3", "ternary_g_nd3"},
|
||||
{"g4", "ternary_g_nd"},
|
||||
{"g5", "ternary_g_nd"},
|
||||
};
|
||||
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
|
||||
for (auto [name, func] : kernel_types) {
|
||||
std::string template_def;
|
||||
if (name == "g4" || name == "g5") {
|
||||
int dim = std::stoi(name.substr(1));
|
||||
template_def = get_template_definition(
|
||||
name + "_" + lib_name, func, get_type_string(type), op, dim);
|
||||
} else {
|
||||
template_def = get_template_definition(
|
||||
name + "_" + lib_name, func, get_type_string(type), op);
|
||||
}
|
||||
kernel_source << template_def;
|
||||
}
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
|
@ -15,24 +15,28 @@ MTL::ComputePipelineState* get_arange_kernel(
|
||||
MTL::ComputePipelineState* get_unary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out);
|
||||
Dtype out_type,
|
||||
const std::string op);
|
||||
|
||||
MTL::ComputePipelineState* get_binary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& out);
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op);
|
||||
|
||||
MTL::ComputePipelineState* get_binary_two_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& out);
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op);
|
||||
|
||||
MTL::ComputePipelineState* get_ternary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out);
|
||||
Dtype type,
|
||||
const std::string op);
|
||||
|
||||
MTL::ComputePipelineState* get_copy_kernel(
|
||||
metal::Device& d,
|
||||
|
@ -4,148 +4,91 @@
|
||||
#include <metal_math>
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary.h"
|
||||
|
||||
#define instantiate_binary(name, itype, otype, op, bopt) \
|
||||
template \
|
||||
[[host_name(name)]] [[kernel]] void binary_##bopt<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \
|
||||
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \
|
||||
instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4) \
|
||||
instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5)
|
||||
|
||||
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
|
||||
template [[host_name("g" #dims name)]] [[kernel]] void \
|
||||
binary_g_nd<itype, otype, op, dims>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const int shape[dims], \
|
||||
constant const size_t a_strides[dims], \
|
||||
constant const size_t b_strides[dims], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
#define instantiate_binary_integer(op) \
|
||||
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
|
||||
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
|
||||
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
|
||||
instantiate_binary_all(op, uint64, uint64_t, uint64_t) \
|
||||
instantiate_binary_all(op, int8, int8_t, int8_t) \
|
||||
instantiate_binary_all(op, int16, int16_t, int16_t) \
|
||||
instantiate_binary_all(op, int32, int32_t, int32_t) \
|
||||
instantiate_binary_all(op, int64, int64_t, int64_t)
|
||||
|
||||
#define instantiate_binary_g_nd(name, itype, otype, op) \
|
||||
template [[host_name("g1" name)]] [[kernel]] void \
|
||||
binary_g_nd1<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t& a_stride, \
|
||||
constant const size_t& b_stride, \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("g2" name)]] [[kernel]] void \
|
||||
binary_g_nd2<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t a_strides[2], \
|
||||
constant const size_t b_strides[2], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g3" name)]] [[kernel]] void \
|
||||
binary_g_nd3<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t a_strides[3], \
|
||||
constant const size_t b_strides[3], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
instantiate_binary_g_dim(name, itype, otype, op, 4) \
|
||||
instantiate_binary_g_dim(name, itype, otype, op, 5)
|
||||
#define instantiate_binary_float(op) \
|
||||
instantiate_binary_all(op, float16, half, half) \
|
||||
instantiate_binary_all(op, float32, float, float) \
|
||||
instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t)
|
||||
|
||||
#define instantiate_binary_g(name, itype, otype, op) \
|
||||
template [[host_name("gn" name)]] [[kernel]] void binary_g<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const int* shape, \
|
||||
constant const size_t* a_strides, \
|
||||
constant const size_t* b_strides, \
|
||||
constant const int& ndim, \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
#define instantiate_binary_types(op) \
|
||||
instantiate_binary_all(op, bool_, bool, bool) \
|
||||
instantiate_binary_integer(op) \
|
||||
instantiate_binary_all(op, complex64, complex64_t, complex64_t) \
|
||||
instantiate_binary_float(op)
|
||||
|
||||
#define instantiate_binary_all(name, tname, itype, otype, op) \
|
||||
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
|
||||
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
|
||||
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
|
||||
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
|
||||
instantiate_binary_g(#name #tname, itype, otype, op) \
|
||||
instantiate_binary_g_nd(#name #tname, itype, otype, op)
|
||||
#define instantiate_binary_types_bool(op) \
|
||||
instantiate_binary_all(op, bool_, bool, bool) \
|
||||
instantiate_binary_all(op, uint8, uint8_t, bool) \
|
||||
instantiate_binary_all(op, uint16, uint16_t, bool) \
|
||||
instantiate_binary_all(op, uint32, uint32_t, bool) \
|
||||
instantiate_binary_all(op, uint64, uint64_t, bool) \
|
||||
instantiate_binary_all(op, int8, int8_t, bool) \
|
||||
instantiate_binary_all(op, int16, int16_t, bool) \
|
||||
instantiate_binary_all(op, int32, int32_t, bool) \
|
||||
instantiate_binary_all(op, int64, int64_t, bool) \
|
||||
instantiate_binary_all(op, float16, half, bool) \
|
||||
instantiate_binary_all(op, float32, float, bool) \
|
||||
instantiate_binary_all(op, bfloat16, bfloat16_t, bool) \
|
||||
instantiate_binary_all(op, complex64, complex64_t, bool)
|
||||
|
||||
#define instantiate_binary_integer(name, op) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
|
||||
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
|
||||
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
|
||||
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
|
||||
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
|
||||
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
|
||||
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
|
||||
instantiate_binary_all(name, int64, int64_t, int64_t, op)
|
||||
|
||||
#define instantiate_binary_float(name, op) \
|
||||
instantiate_binary_all(name, float16, half, half, op) \
|
||||
instantiate_binary_all(name, float32, float, float, op) \
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
|
||||
|
||||
#define instantiate_binary_types(name, op) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||
instantiate_binary_integer(name, op) \
|
||||
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
|
||||
instantiate_binary_float(name, op)
|
||||
|
||||
#define instantiate_binary_types_bool(name, op) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, bool, op) \
|
||||
instantiate_binary_all(name, uint16, uint16_t, bool, op) \
|
||||
instantiate_binary_all(name, uint32, uint32_t, bool, op) \
|
||||
instantiate_binary_all(name, uint64, uint64_t, bool, op) \
|
||||
instantiate_binary_all(name, int8, int8_t, bool, op) \
|
||||
instantiate_binary_all(name, int16, int16_t, bool, op) \
|
||||
instantiate_binary_all(name, int32, int32_t, bool, op) \
|
||||
instantiate_binary_all(name, int64, int64_t, bool, op) \
|
||||
instantiate_binary_all(name, float16, half, bool, op) \
|
||||
instantiate_binary_all(name, float32, float, bool, op) \
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
|
||||
instantiate_binary_all(name, complex64, complex64_t, bool, op)
|
||||
|
||||
instantiate_binary_types(add, Add)
|
||||
instantiate_binary_types(div, Divide)
|
||||
instantiate_binary_types_bool(eq, Equal)
|
||||
instantiate_binary_types_bool(ge, Greater)
|
||||
instantiate_binary_types_bool(geq, GreaterEqual)
|
||||
instantiate_binary_types_bool(le, Less)
|
||||
instantiate_binary_types_bool(leq, LessEqual)
|
||||
instantiate_binary_types_bool(neq, NotEqual)
|
||||
instantiate_binary_float(lae, LogAddExp)
|
||||
instantiate_binary_types(max, Maximum)
|
||||
instantiate_binary_types(min, Minimum)
|
||||
instantiate_binary_types(mul, Multiply)
|
||||
instantiate_binary_types(sub, Subtract)
|
||||
instantiate_binary_types(pow, Power)
|
||||
instantiate_binary_types(rem, Remainder)
|
||||
instantiate_binary_float(arctan2, ArcTan2)
|
||||
instantiate_binary_types(Add)
|
||||
instantiate_binary_types(Divide)
|
||||
instantiate_binary_types_bool(Equal)
|
||||
instantiate_binary_types_bool(Greater)
|
||||
instantiate_binary_types_bool(GreaterEqual)
|
||||
instantiate_binary_types_bool(Less)
|
||||
instantiate_binary_types_bool(LessEqual)
|
||||
instantiate_binary_types_bool(NotEqual)
|
||||
instantiate_binary_float(LogAddExp)
|
||||
instantiate_binary_types(Maximum)
|
||||
instantiate_binary_types(Minimum)
|
||||
instantiate_binary_types(Multiply)
|
||||
instantiate_binary_types(Subtract)
|
||||
instantiate_binary_types(Power)
|
||||
instantiate_binary_types(Remainder)
|
||||
instantiate_binary_float(ArcTan2)
|
||||
|
||||
// NaNEqual only needed for floating point types with boolean output
|
||||
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
|
||||
instantiate_binary_all(naneq, float32, float, bool, NaNEqual)
|
||||
instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
|
||||
instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
|
||||
instantiate_binary_all(NaNEqual, float16, half, bool)
|
||||
instantiate_binary_all(NaNEqual, float32, float, bool)
|
||||
instantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool)
|
||||
instantiate_binary_all(NaNEqual, complex64, complex64_t, bool)
|
||||
|
||||
instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
|
||||
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
|
||||
instantiate_binary_all(LogicalOr, bool_, bool, bool)
|
||||
instantiate_binary_all(LogicalAnd, bool_, bool, bool)
|
||||
|
||||
// Bitwise ops only need integer types and bool (except for l/r shift)
|
||||
instantiate_binary_integer(bitwise_and, BitwiseAnd)
|
||||
instantiate_binary_all(bitwise_and, bool_, bool, bool, BitwiseAnd)
|
||||
instantiate_binary_integer(bitwise_or, BitwiseOr)
|
||||
instantiate_binary_all(bitwise_or, bool_, bool, bool, BitwiseOr)
|
||||
instantiate_binary_integer(bitwise_xor, BitwiseXor)
|
||||
instantiate_binary_all(bitwise_xor, bool_, bool, bool, BitwiseXor)
|
||||
instantiate_binary_integer(left_shift, LeftShift)
|
||||
instantiate_binary_integer(right_shift, RightShift) // clang-format on
|
||||
instantiate_binary_integer(BitwiseAnd)
|
||||
instantiate_binary_all(BitwiseAnd, bool_, bool, bool)
|
||||
instantiate_binary_integer(BitwiseOr)
|
||||
instantiate_binary_all(BitwiseOr, bool_, bool, bool)
|
||||
instantiate_binary_integer(BitwiseXor)
|
||||
instantiate_binary_all(BitwiseXor, bool_, bool, bool)
|
||||
instantiate_binary_integer(LeftShift)
|
||||
instantiate_binary_integer(RightShift) // clang-format on
|
||||
|
@ -7,99 +7,34 @@
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary_two.h"
|
||||
|
||||
#define instantiate_binary(name, itype, otype, op, bopt) \
|
||||
template [[host_name(name)]] [[kernel]] void \
|
||||
binary_##bopt<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \
|
||||
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \
|
||||
instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4) \
|
||||
instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5)
|
||||
|
||||
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
|
||||
template [[host_name("g" #dims name)]] [[kernel]] void \
|
||||
binary_g_nd<itype, otype, op, dims>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const int shape[dims], \
|
||||
constant const size_t a_strides[dims], \
|
||||
constant const size_t b_strides[dims], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
#define instantiate_binary_float(op) \
|
||||
instantiate_binary_all(op, float16, half, half) \
|
||||
instantiate_binary_all(op, float32, float, float) \
|
||||
instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t)
|
||||
|
||||
#define instantiate_binary_g_nd(name, itype, otype, op) \
|
||||
template [[host_name("g1" name)]] [[kernel]] void \
|
||||
binary_g_nd1<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const size_t& a_stride, \
|
||||
constant const size_t& b_stride, \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("g2" name)]] [[kernel]] void \
|
||||
binary_g_nd2<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const size_t a_strides[2], \
|
||||
constant const size_t b_strides[2], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g3" name)]] [[kernel]] void \
|
||||
binary_g_nd3<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const size_t a_strides[3], \
|
||||
constant const size_t b_strides[3], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
instantiate_binary_g_dim(name, itype, otype, op, 4) \
|
||||
instantiate_binary_g_dim(name, itype, otype, op, 5)
|
||||
#define instantiate_binary_types(op) \
|
||||
instantiate_binary_all(op, bool_, bool, bool) \
|
||||
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
|
||||
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
|
||||
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
|
||||
instantiate_binary_all(op, uint64, uint64_t, uint64_t) \
|
||||
instantiate_binary_all(op, int8, int8_t, int8_t) \
|
||||
instantiate_binary_all(op, int16, int16_t, int16_t) \
|
||||
instantiate_binary_all(op, int32, int32_t, int32_t) \
|
||||
instantiate_binary_all(op, int64, int64_t, int64_t) \
|
||||
instantiate_binary_all(op, complex64, complex64_t, complex64_t) \
|
||||
instantiate_binary_float(op)
|
||||
|
||||
#define instantiate_binary_g(name, itype, otype, op) \
|
||||
template [[host_name("gn" name)]] [[kernel]] void \
|
||||
binary_g<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const int* shape, \
|
||||
constant const size_t* a_strides, \
|
||||
constant const size_t* b_strides, \
|
||||
constant const int& ndim, \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
#define instantiate_binary_all(name, tname, itype, otype, op) \
|
||||
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
|
||||
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
|
||||
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
|
||||
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
|
||||
instantiate_binary_g(#name #tname, itype, otype, op) \
|
||||
instantiate_binary_g_nd(#name #tname, itype, otype, op)
|
||||
|
||||
#define instantiate_binary_float(name, op) \
|
||||
instantiate_binary_all(name, float16, half, half, op) \
|
||||
instantiate_binary_all(name, float32, float, float, op) \
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
|
||||
|
||||
#define instantiate_binary_types(name, op) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
|
||||
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
|
||||
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
|
||||
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
|
||||
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
|
||||
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
|
||||
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
|
||||
instantiate_binary_all(name, int64, int64_t, int64_t, op) \
|
||||
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
|
||||
instantiate_binary_float(name, op)
|
||||
|
||||
instantiate_binary_types(divmod, DivMod) // clang-format on
|
||||
instantiate_binary_types(DivMod) // clang-format on
|
||||
|
@ -9,96 +9,28 @@
|
||||
#include "mlx/backend/metal/kernels/ternary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/ternary.h"
|
||||
|
||||
#define instantiate_ternary_v(name, type, op) \
|
||||
template [[host_name("v_" name)]] [[kernel]] void ternary_v<type, op>( \
|
||||
device const bool* a, \
|
||||
device const type* b, \
|
||||
device const type* c, \
|
||||
device type* d, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
#define instantiate_ternary_all(op, tname, type) \
|
||||
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
|
||||
instantiate_kernel("g_" #op #tname, ternary_g, type, op) \
|
||||
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
|
||||
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
|
||||
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op) \
|
||||
instantiate_kernel("g4_" #op #tname, ternary_g_nd, type, op, 4) \
|
||||
instantiate_kernel("g5_" #op #tname, ternary_g_nd, type, op, 5)
|
||||
|
||||
#define instantiate_ternary_g(name, type, op) \
|
||||
template [[host_name("g_" name)]] [[kernel]] void ternary_g<type, op>( \
|
||||
device const bool* a, \
|
||||
device const type* b, \
|
||||
device const type* c, \
|
||||
device type* d, \
|
||||
constant const int* shape, \
|
||||
constant const size_t* a_strides, \
|
||||
constant const size_t* b_strides, \
|
||||
constant const size_t* c_strides, \
|
||||
constant const int& ndim, \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
#define instantiate_ternary_types(op) \
|
||||
instantiate_ternary_all(op, bool_, bool) \
|
||||
instantiate_ternary_all(op, uint8, uint8_t) \
|
||||
instantiate_ternary_all(op, uint16, uint16_t) \
|
||||
instantiate_ternary_all(op, uint32, uint32_t) \
|
||||
instantiate_ternary_all(op, uint64, uint64_t) \
|
||||
instantiate_ternary_all(op, int8, int8_t) \
|
||||
instantiate_ternary_all(op, int16, int16_t) \
|
||||
instantiate_ternary_all(op, int32, int32_t) \
|
||||
instantiate_ternary_all(op, int64, int64_t) \
|
||||
instantiate_ternary_all(op, float16, half) \
|
||||
instantiate_ternary_all(op, float32, float) \
|
||||
instantiate_ternary_all(op, bfloat16, bfloat16_t) \
|
||||
instantiate_ternary_all(op, complex64, complex64_t) // clang-format on
|
||||
|
||||
#define instantiate_ternary_g_dim(name, type, op, dims) \
|
||||
template [[host_name("g" #dims "_" name )]] [[kernel]] void \
|
||||
ternary_g_nd<type, op, dims>( \
|
||||
device const bool* a, \
|
||||
device const type* b, \
|
||||
device const type* c, \
|
||||
device type* d, \
|
||||
constant const int shape[dims], \
|
||||
constant const size_t a_strides[dims], \
|
||||
constant const size_t b_strides[dims], \
|
||||
constant const size_t c_strides[dims], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
#define instantiate_ternary_g_nd(name, type, op) \
|
||||
template [[host_name("g1_" name)]] [[kernel]] void \
|
||||
ternary_g_nd1<type, op>( \
|
||||
device const bool* a, \
|
||||
device const type* b, \
|
||||
device const type* c, \
|
||||
device type* d, \
|
||||
constant const size_t& a_strides, \
|
||||
constant const size_t& b_strides, \
|
||||
constant const size_t& c_strides, \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("g2_" name)]] [[kernel]] void \
|
||||
ternary_g_nd2<type, op>( \
|
||||
device const bool* a, \
|
||||
device const type* b, \
|
||||
device const type* c, \
|
||||
device type* d, \
|
||||
constant const size_t a_strides[2], \
|
||||
constant const size_t b_strides[2], \
|
||||
constant const size_t c_strides[2], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g3_" name)]] [[kernel]] void \
|
||||
ternary_g_nd3<type, op>( \
|
||||
device const bool* a, \
|
||||
device const type* b, \
|
||||
device const type* c, \
|
||||
device type* d, \
|
||||
constant const size_t a_strides[3], \
|
||||
constant const size_t b_strides[3], \
|
||||
constant const size_t c_strides[3], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
instantiate_ternary_g_dim(name, type, op, 4) \
|
||||
instantiate_ternary_g_dim(name, type, op, 5)
|
||||
|
||||
#define instantiate_ternary_all(name, tname, type, op) \
|
||||
instantiate_ternary_v(#name #tname, type, op) \
|
||||
instantiate_ternary_g(#name #tname, type, op) \
|
||||
instantiate_ternary_g_nd(#name #tname, type, op)
|
||||
|
||||
#define instantiate_ternary_types(name, op) \
|
||||
instantiate_ternary_all(name, bool_, bool, op) \
|
||||
instantiate_ternary_all(name, uint8, uint8_t, op) \
|
||||
instantiate_ternary_all(name, uint16, uint16_t, op) \
|
||||
instantiate_ternary_all(name, uint32, uint32_t, op) \
|
||||
instantiate_ternary_all(name, uint64, uint64_t, op) \
|
||||
instantiate_ternary_all(name, int8, int8_t, op) \
|
||||
instantiate_ternary_all(name, int16, int16_t, op) \
|
||||
instantiate_ternary_all(name, int32, int32_t, op) \
|
||||
instantiate_ternary_all(name, int64, int64_t, op) \
|
||||
instantiate_ternary_all(name, float16, half, op) \
|
||||
instantiate_ternary_all(name, float32, float, op) \
|
||||
instantiate_ternary_all(name, bfloat16, bfloat16_t, op) \
|
||||
instantiate_ternary_all(name, complex64, complex64_t, op) // clang-format on
|
||||
|
||||
instantiate_ternary_types(select, Select)
|
||||
instantiate_ternary_types(Select)
|
||||
|
@ -5,83 +5,68 @@
|
||||
#include "mlx/backend/metal/kernels/unary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/unary.h"
|
||||
|
||||
#define instantiate_unary_v(name, type, op) \
|
||||
template [[host_name(name)]] [[kernel]] void unary_v<type, op>( \
|
||||
device const type* in, \
|
||||
device type* out, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
#define instantiate_unary_all(op, tname, type) \
|
||||
instantiate_kernel("v" #op #tname, unary_v, type, op) \
|
||||
instantiate_kernel("g" #op #tname, unary_g, type, op)
|
||||
|
||||
#define instantiate_unary_g(name, type, op) \
|
||||
template [[host_name(name)]] [[kernel]] void unary_g<type, op>( \
|
||||
device const type* in, \
|
||||
device type* out, \
|
||||
device const int* in_shape, \
|
||||
device const size_t* in_strides, \
|
||||
device const int& ndim, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
#define instantiate_unary_float(op) \
|
||||
instantiate_unary_all(op, float16, half) \
|
||||
instantiate_unary_all(op, float32, float) \
|
||||
instantiate_unary_all(op, bfloat16, bfloat16_t)
|
||||
|
||||
#define instantiate_unary_all(name, tname, type, op) \
|
||||
instantiate_unary_v("v" #name #tname, type, op) \
|
||||
instantiate_unary_g("g" #name #tname, type, op)
|
||||
#define instantiate_unary_types(op) \
|
||||
instantiate_unary_all(op, bool_, bool) \
|
||||
instantiate_unary_all(op, uint8, uint8_t) \
|
||||
instantiate_unary_all(op, uint16, uint16_t) \
|
||||
instantiate_unary_all(op, uint32, uint32_t) \
|
||||
instantiate_unary_all(op, uint64, uint64_t) \
|
||||
instantiate_unary_all(op, int8, int8_t) \
|
||||
instantiate_unary_all(op, int16, int16_t) \
|
||||
instantiate_unary_all(op, int32, int32_t) \
|
||||
instantiate_unary_all(op, int64, int64_t) \
|
||||
instantiate_unary_float(op)
|
||||
|
||||
#define instantiate_unary_float(name, op) \
|
||||
instantiate_unary_all(name, float16, half, op) \
|
||||
instantiate_unary_all(name, float32, float, op) \
|
||||
instantiate_unary_all(name, bfloat16, bfloat16_t, op)
|
||||
instantiate_unary_types(Abs)
|
||||
instantiate_unary_float(ArcCos)
|
||||
instantiate_unary_float(ArcCosh)
|
||||
instantiate_unary_float(ArcSin)
|
||||
instantiate_unary_float(ArcSinh)
|
||||
instantiate_unary_float(ArcTan)
|
||||
instantiate_unary_float(ArcTanh)
|
||||
instantiate_unary_types(Ceil)
|
||||
instantiate_unary_float(Cos)
|
||||
instantiate_unary_float(Cosh)
|
||||
instantiate_unary_float(Exp)
|
||||
instantiate_unary_float(Expm1)
|
||||
instantiate_unary_types(Floor)
|
||||
instantiate_unary_float(Log)
|
||||
instantiate_unary_float(Log2)
|
||||
instantiate_unary_float(Log10)
|
||||
instantiate_unary_float(Log1p)
|
||||
instantiate_unary_types(Negative)
|
||||
instantiate_unary_float(Sigmoid)
|
||||
instantiate_unary_float(Erf)
|
||||
instantiate_unary_float(ErfInv)
|
||||
instantiate_unary_types(Sign)
|
||||
instantiate_unary_float(Sin)
|
||||
instantiate_unary_float(Sinh)
|
||||
instantiate_unary_types(Square)
|
||||
instantiate_unary_float(Sqrt)
|
||||
instantiate_unary_float(Rsqrt)
|
||||
instantiate_unary_float(Tan)
|
||||
instantiate_unary_float(Tanh)
|
||||
instantiate_unary_float(Round)
|
||||
|
||||
#define instantiate_unary_types(name, op) \
|
||||
instantiate_unary_all(name, bool_, bool, op) \
|
||||
instantiate_unary_all(name, uint8, uint8_t, op) \
|
||||
instantiate_unary_all(name, uint16, uint16_t, op) \
|
||||
instantiate_unary_all(name, uint32, uint32_t, op) \
|
||||
instantiate_unary_all(name, uint64, uint64_t, op) \
|
||||
instantiate_unary_all(name, int8, int8_t, op) \
|
||||
instantiate_unary_all(name, int16, int16_t, op) \
|
||||
instantiate_unary_all(name, int32, int32_t, op) \
|
||||
instantiate_unary_all(name, int64, int64_t, op) \
|
||||
instantiate_unary_float(name, op)
|
||||
instantiate_unary_all(Abs, complex64, complex64_t)
|
||||
instantiate_unary_all(Conjugate, complex64, complex64_t)
|
||||
instantiate_unary_all(Cos, complex64, complex64_t)
|
||||
instantiate_unary_all(Cosh, complex64, complex64_t)
|
||||
instantiate_unary_all(Exp, complex64, complex64_t)
|
||||
instantiate_unary_all(Negative, complex64, complex64_t)
|
||||
instantiate_unary_all(Sin, complex64, complex64_t)
|
||||
instantiate_unary_all(Sinh, complex64, complex64_t)
|
||||
instantiate_unary_all(Tan, complex64, complex64_t)
|
||||
instantiate_unary_all(Tanh, complex64, complex64_t)
|
||||
instantiate_unary_all(Round, complex64, complex64_t)
|
||||
|
||||
instantiate_unary_types(abs, Abs)
|
||||
instantiate_unary_float(arccos, ArcCos)
|
||||
instantiate_unary_float(arccosh, ArcCosh)
|
||||
instantiate_unary_float(arcsin, ArcSin)
|
||||
instantiate_unary_float(arcsinh, ArcSinh)
|
||||
instantiate_unary_float(arctan, ArcTan)
|
||||
instantiate_unary_float(arctanh, ArcTanh)
|
||||
instantiate_unary_types(ceil, Ceil)
|
||||
instantiate_unary_float(cos, Cos)
|
||||
instantiate_unary_float(cosh, Cosh)
|
||||
instantiate_unary_float(exp, Exp)
|
||||
instantiate_unary_float(expm1, Expm1)
|
||||
instantiate_unary_types(floor, Floor)
|
||||
instantiate_unary_float(log, Log)
|
||||
instantiate_unary_float(log2, Log2)
|
||||
instantiate_unary_float(log10, Log10)
|
||||
instantiate_unary_float(log1p, Log1p)
|
||||
instantiate_unary_types(neg, Negative)
|
||||
instantiate_unary_float(sigmoid, Sigmoid)
|
||||
instantiate_unary_float(erf, Erf)
|
||||
instantiate_unary_float(erfinv, ErfInv)
|
||||
instantiate_unary_types(sign, Sign)
|
||||
instantiate_unary_float(sin, Sin)
|
||||
instantiate_unary_float(sinh, Sinh)
|
||||
instantiate_unary_types(square, Square)
|
||||
instantiate_unary_float(sqrt, Sqrt)
|
||||
instantiate_unary_float(rsqrt, Rsqrt)
|
||||
instantiate_unary_float(tan, Tan)
|
||||
instantiate_unary_float(tanh, Tanh)
|
||||
instantiate_unary_float(round, Round)
|
||||
|
||||
instantiate_unary_all(abs, complex64, complex64_t, Abs)
|
||||
instantiate_unary_all(conj, complex64, complex64_t, Conjugate)
|
||||
instantiate_unary_all(cos, complex64, complex64_t, Cos)
|
||||
instantiate_unary_all(cosh, complex64, complex64_t, Cosh)
|
||||
instantiate_unary_all(exp, complex64, complex64_t, Exp)
|
||||
instantiate_unary_all(neg, complex64, complex64_t, Negative)
|
||||
instantiate_unary_all(sin, complex64, complex64_t, Sin)
|
||||
instantiate_unary_all(sinh, complex64, complex64_t, Sinh)
|
||||
instantiate_unary_all(tan, complex64, complex64_t, Tan)
|
||||
instantiate_unary_all(tanh, complex64, complex64_t, Tanh)
|
||||
instantiate_unary_all(round, complex64, complex64_t, Round)
|
||||
|
||||
instantiate_unary_all(lnot, bool_, bool, LogicalNot) // clang-format on
|
||||
instantiate_unary_all(LogicalNot, bool_, bool) // clang-format on
|
||||
|
@ -15,30 +15,34 @@ MTL::ComputePipelineState* get_arange_kernel(
|
||||
MTL::ComputePipelineState* get_unary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array&) {
|
||||
Dtype,
|
||||
const std::string) {
|
||||
return d.get_kernel(kernel_name);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_binary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array&,
|
||||
const array&) {
|
||||
Dtype,
|
||||
Dtype,
|
||||
const std::string) {
|
||||
return d.get_kernel(kernel_name);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_binary_two_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array&,
|
||||
const array&) {
|
||||
Dtype,
|
||||
Dtype,
|
||||
const std::string) {
|
||||
return d.get_kernel(kernel_name);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_ternary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array&) {
|
||||
Dtype,
|
||||
const std::string) {
|
||||
return d.get_kernel(kernel_name);
|
||||
}
|
||||
|
||||
|
@ -49,7 +49,7 @@ void ternary_op_gpu_inplace(
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto kernel = get_ternary_kernel(d, kernel_name, out);
|
||||
auto kernel = get_ternary_kernel(d, kernel_name, out.dtype(), op);
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
@ -122,7 +122,7 @@ void ternary_op_gpu(
|
||||
}
|
||||
|
||||
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
ternary_op_gpu(inputs, out, "select");
|
||||
ternary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -5,6 +5,11 @@
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#define UNARY_GPU(func) \
|
||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||
unary_op_gpu(inputs, out, get_primitive_string(this)); \
|
||||
}
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void unary_op_gpu_inplace(
|
||||
@ -21,7 +26,7 @@ void unary_op_gpu_inplace(
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
std::string kernel_name = (contig ? "v" : "g") + op + type_to_name(out);
|
||||
auto kernel = get_unary_kernel(d, kernel_name, out);
|
||||
auto kernel = get_unary_kernel(d, kernel_name, out.dtype(), op);
|
||||
|
||||
size_t nthreads = contig ? in.data_size() : in.size();
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
@ -77,148 +82,57 @@ void unary_op_gpu(
|
||||
unary_op_gpu(inputs, out, op, s);
|
||||
}
|
||||
|
||||
void Abs::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "abs");
|
||||
}
|
||||
|
||||
void ArcCos::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "arccos");
|
||||
}
|
||||
|
||||
void ArcCosh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "arccosh");
|
||||
}
|
||||
|
||||
void ArcSin::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "arcsin");
|
||||
}
|
||||
|
||||
void ArcSinh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "arcsinh");
|
||||
}
|
||||
|
||||
void ArcTan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "arctan");
|
||||
}
|
||||
|
||||
void ArcTanh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "arctanh");
|
||||
}
|
||||
|
||||
void Conjugate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == complex64) {
|
||||
unary_op_gpu(inputs, out, "conj");
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[conjugate] conjugate must be called on complex input.");
|
||||
}
|
||||
}
|
||||
|
||||
void Cos::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "cos");
|
||||
}
|
||||
|
||||
void Cosh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "cosh");
|
||||
}
|
||||
|
||||
void Erf::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "erf");
|
||||
}
|
||||
|
||||
void ErfInv::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "erfinv");
|
||||
}
|
||||
|
||||
void Exp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "exp");
|
||||
}
|
||||
|
||||
void Expm1::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "expm1");
|
||||
}
|
||||
UNARY_GPU(Abs)
|
||||
UNARY_GPU(ArcCos)
|
||||
UNARY_GPU(ArcCosh)
|
||||
UNARY_GPU(ArcSin)
|
||||
UNARY_GPU(ArcSinh)
|
||||
UNARY_GPU(ArcTan)
|
||||
UNARY_GPU(ArcTanh)
|
||||
UNARY_GPU(Conjugate)
|
||||
UNARY_GPU(Cos)
|
||||
UNARY_GPU(Cosh)
|
||||
UNARY_GPU(Erf)
|
||||
UNARY_GPU(ErfInv)
|
||||
UNARY_GPU(Exp)
|
||||
UNARY_GPU(Expm1)
|
||||
UNARY_GPU(Log1p)
|
||||
UNARY_GPU(LogicalNot)
|
||||
UNARY_GPU(Floor)
|
||||
UNARY_GPU(Ceil)
|
||||
UNARY_GPU(Negative)
|
||||
UNARY_GPU(Sigmoid)
|
||||
UNARY_GPU(Sign)
|
||||
UNARY_GPU(Sin)
|
||||
UNARY_GPU(Sinh)
|
||||
UNARY_GPU(Square)
|
||||
UNARY_GPU(Sqrt)
|
||||
UNARY_GPU(Tan)
|
||||
UNARY_GPU(Tanh)
|
||||
|
||||
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
unary_op_gpu(inputs, out, "log");
|
||||
unary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
break;
|
||||
case Base::two:
|
||||
unary_op_gpu(inputs, out, "log2");
|
||||
unary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
break;
|
||||
case Base::ten:
|
||||
unary_op_gpu(inputs, out, "log10");
|
||||
unary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Log1p::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "log1p");
|
||||
}
|
||||
|
||||
void LogicalNot::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "lnot");
|
||||
}
|
||||
|
||||
void Floor::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "floor");
|
||||
}
|
||||
|
||||
void Ceil::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "ceil");
|
||||
}
|
||||
|
||||
void Negative::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "neg");
|
||||
}
|
||||
|
||||
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_op_gpu(inputs, out, "round");
|
||||
unary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Sigmoid::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "sigmoid");
|
||||
}
|
||||
|
||||
void Sign::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "sign");
|
||||
}
|
||||
|
||||
void Sin::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "sin");
|
||||
}
|
||||
|
||||
void Sinh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "sinh");
|
||||
}
|
||||
|
||||
void Square::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "square");
|
||||
}
|
||||
|
||||
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (recip_) {
|
||||
unary_op_gpu(inputs, out, "rsqrt");
|
||||
} else {
|
||||
unary_op_gpu(inputs, out, "sqrt");
|
||||
}
|
||||
}
|
||||
|
||||
void Tan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "tan");
|
||||
}
|
||||
|
||||
void Tanh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op_gpu(inputs, out, "tanh");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -141,6 +141,12 @@ int next_power_of_2(int n) {
|
||||
return pow(2, std::ceil(std::log2(n)));
|
||||
}
|
||||
|
||||
std::string get_primitive_string(Primitive* primitive) {
|
||||
std::ostringstream op_t;
|
||||
primitive->print(op_t);
|
||||
return op_t.str();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mlx::core
|
||||
|
Loading…
Reference in New Issue
Block a user