Refactor JIT for unary/binary/ternary ops (#1206)

* refactor unary/binary/ternary ops

* get_primitive_string util

---------
Alex Barron 2024-06-12 14:22:12 -07:00 committed by GitHub
parent de2b9e7d0a
commit 934683088e
16 changed files with 379 additions and 935 deletions
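
At a glance, the refactor replaces per-primitive eval_gpu bodies that passed hand-written kernel-name strings with one macro per op family, deriving the kernel name from the primitive itself via the new get_primitive_string helper. A minimal before/after sketch of the pattern, taken from the binary path below:

// Before: every primitive repeated the same boilerplate with its own string.
void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op_gpu(inputs, out, "add");
}

// After: one line per primitive; the kernel name is the primitive's printed name.
BINARY_GPU(Add)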

View File

@ -6,6 +6,17 @@
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#define BINARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
binary_op_gpu(inputs, out, get_primitive_string(this)); \
}
#define BINARY_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
binary_op_gpu(inputs, outputs, get_primitive_string(this)); \
}
namespace mlx::core {
constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
@ -61,7 +72,8 @@ void binary_op_gpu_inplace(
auto& d = metal::device(s.device);
auto kernel = get_binary_two_kernel(d, kernel_name, a, outputs[0]);
auto kernel =
get_binary_two_kernel(d, kernel_name, a.dtype(), outputs[0].dtype(), op);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
@ -188,7 +200,7 @@ void binary_op_gpu_inplace(
auto& d = metal::device(s.device);
auto kernel = get_binary_kernel(d, kernel_name, a, out);
auto kernel = get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_a = a.data_shared_ptr() == nullptr;
@ -259,102 +271,44 @@ void binary_op_gpu(
binary_op_gpu(inputs, out, op, s);
}
void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "add");
}
void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "arctan2");
}
BINARY_GPU(Add)
BINARY_GPU(ArcTan2)
BINARY_GPU(Divide)
BINARY_GPU_MULTI(DivMod)
BINARY_GPU(Remainder)
BINARY_GPU(Equal)
BINARY_GPU(Greater)
BINARY_GPU(GreaterEqual)
BINARY_GPU(Less)
BINARY_GPU(LessEqual)
BINARY_GPU(LogicalAnd)
BINARY_GPU(LogicalOr)
BINARY_GPU(LogAddExp)
BINARY_GPU(Maximum)
BINARY_GPU(Minimum)
BINARY_GPU(Multiply)
BINARY_GPU(NotEqual)
BINARY_GPU(Power)
BINARY_GPU(Subtract)
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu(inputs, out, "bitwise_and");
binary_op_gpu(inputs, out, get_primitive_string(this));
break;
case BitwiseBinary::Or:
binary_op_gpu(inputs, out, "bitwise_or");
binary_op_gpu(inputs, out, get_primitive_string(this));
break;
case BitwiseBinary::Xor:
binary_op_gpu(inputs, out, "bitwise_xor");
binary_op_gpu(inputs, out, get_primitive_string(this));
break;
case BitwiseBinary::LeftShift:
binary_op_gpu(inputs, out, "left_shift");
binary_op_gpu(inputs, out, get_primitive_string(this));
break;
case BitwiseBinary::RightShift:
binary_op_gpu(inputs, out, "right_shift");
binary_op_gpu(inputs, out, get_primitive_string(this));
break;
}
}
void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "div");
}
void DivMod::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
binary_op_gpu(inputs, outputs, "divmod");
}
void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "rem");
}
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, equal_nan_ ? "naneq" : "eq");
}
void Greater::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "ge");
}
void GreaterEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "geq");
}
void Less::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "le");
}
void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "leq");
}
void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "land");
}
void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "lor");
}
void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "lae");
}
void Maximum::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "max");
}
void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "min");
}
void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "mul");
}
void NotEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "neq");
}
void Power::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "pow");
}
void Subtract::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "sub");
}
} // namespace mlx::core
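
As a concrete illustration, the multi-output variant BINARY_GPU_MULTI(DivMod) used above expands to roughly the body it replaces:

void DivMod::eval_gpu(
    const std::vector<array>& inputs, std::vector<array>& outputs) {
  // get_primitive_string(this) yields the primitive's printed name ("DivMod"),
  // which now selects the JIT kernel instead of the old hard-coded "divmod".
  binary_op_gpu(inputs, outputs, get_primitive_string(this));
}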

View File

@ -391,16 +391,16 @@ void multi_upload_bluestein_fft(
std::vector<int> rstrides(in.ndim(), 1);
rstarts[axis] = in.shape(axis) - back_offset;
rstrides[axis] = -1;
unary_op_gpu({in}, conj_temp, "conj", s);
unary_op_gpu({in}, conj_temp, "Conjugate", s);
slice_gpu(in, slice_temp, rstarts, rstrides, s);
concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);
} else if (inverse) {
unary_op_gpu({in}, temp, "conj", s);
unary_op_gpu({in}, temp, "Conjugate", s);
} else {
temp.copy_shared_buffer(in);
}
binary_op_gpu({temp, w_k_broadcast}, temp1, "mul", s);
binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s);
std::vector<std::pair<int, int>> pads;
auto padded_shape = out.shape();
@ -419,7 +419,7 @@ void multi_upload_bluestein_fft(
/*inplace=*/false,
s);
binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "mul", s);
binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s);
fft_op(
pad_temp,
@ -437,7 +437,7 @@ void multi_upload_bluestein_fft(
starts[axis] = plan.bluestein_n - offset - n;
slice_gpu(pad_temp1, temp, starts, strides, s);
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "mul", s);
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s);
if (real && !inverse) {
std::vector<int> rstarts(in.ndim(), 0);
@ -451,11 +451,11 @@ void multi_upload_bluestein_fft(
copies.push_back(inv_n);
copy_gpu(temp1, temp_float, CopyType::General, s);
binary_op_gpu({temp_float, inv_n}, out, "mul", s);
binary_op_gpu({temp_float, inv_n}, out, "Multiply", s);
} else if (inverse) {
auto inv_n = array({1.0f / n}, {1}, complex64);
unary_op_gpu({temp1}, temp, "conj", s);
binary_op_gpu({temp, inv_n}, out, "mul", s);
unary_op_gpu({temp1}, temp, "Conjugate", s);
binary_op_gpu({temp, inv_n}, out, "Multiply", s);
copies.push_back(inv_n);
} else {
out.copy_shared_buffer(temp1);

View File

@ -1,87 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view binary_kernels = R"(
template [[host_name("ss{0}")]] [[kernel]]
void binary_ss<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("vs{0}")]] [[kernel]]
void binary_vs<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("sv{0}")]] [[kernel]]
void binary_sv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("vv{0}")]] [[kernel]]
void binary_vv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("g4{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 4>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const int shape[4],
constant const size_t a_strides[4],
constant const size_t b_strides[4],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 5>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const int shape[5],
constant const size_t a_strides[5],
constant const size_t b_strides[5],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g1{0}")]] [[kernel]] void
binary_g_nd1<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]);
template [[host_name("g2{0}")]] [[kernel]] void
binary_g_nd2<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3{0}")]] [[kernel]] void
binary_g_nd3<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gn{0}")]] [[kernel]]
void binary_g<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
)";

View File

@ -1,98 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view binary_two_kernels = R"(
template [[host_name("ss{0}")]] [[kernel]]
void binary_ss<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("vs{0}")]] [[kernel]]
void binary_vs<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("sv{0}")]] [[kernel]]
void binary_sv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("vv{0}")]] [[kernel]]
void binary_vv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("g4{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 4>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const int shape[4],
constant const size_t a_strides[4],
constant const size_t b_strides[4],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 5>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const int shape[5],
constant const size_t a_strides[5],
constant const size_t b_strides[5],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g1{0}")]] [[kernel]] void
binary_g_nd1<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]);
template [[host_name("g2{0}")]] [[kernel]] void
binary_g_nd2<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3{0}")]] [[kernel]] void
binary_g_nd3<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gn{0}")]] [[kernel]]
void binary_g<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
)";

View File

@ -1,80 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view ternary_kernels = R"(
template [[host_name("v_{0}")]] [[kernel]] void ternary_v<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("g_{0}")]] [[kernel]] void ternary_g<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g1_{0}")]] [[kernel]] void
ternary_g_nd1<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const size_t& a_strides,
constant const size_t& b_strides,
constant const size_t& c_strides,
uint index [[thread_position_in_grid]]);
template [[host_name("g2_{0}")]] [[kernel]] void
ternary_g_nd2<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
constant const size_t c_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3_{0}")]] [[kernel]] void
ternary_g_nd3<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
constant const size_t c_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g4_{0}")]] [[kernel]] void
ternary_g_nd<{1}, {2}, 4>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const int shape[4],
constant const size_t a_strides[4],
constant const size_t b_strides[4],
constant const size_t c_strides[4],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5_{0}")]] [[kernel]] void
ternary_g_nd<{1}, {2}, 5>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const int shape[5],
constant const size_t a_strides[5],
constant const size_t b_strides[5],
constant const size_t c_strides[5],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
)";

View File

@ -1,16 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view unary_kernels = R"(
template [[host_name("v{0}")]] [[kernel]] void unary_v<{1}, {2}>(
device const {1}* in,
device {1}* out,
uint index [[thread_position_in_grid]]);
template [[host_name("g{0}")]] [[kernel]] void unary_g<{1}, {2}>(
device const {1}* in,
device {1}* out,
device const int* in_shape,
device const size_t* in_strides,
device const int& ndim,
uint index [[thread_position_in_grid]]);
)";

View File

@ -1,9 +1,8 @@
// Copyright © 2024 Apple Inc.
#include <map>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/binary.h"
#include "mlx/backend/metal/jit/binary_two.h"
#include "mlx/backend/metal/jit/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/reduce.h"
@ -12,8 +11,6 @@
#include "mlx/backend/metal/jit/sort.h"
#include "mlx/backend/metal/jit/steel_conv.h"
#include "mlx/backend/metal/jit/steel_gemm.h"
#include "mlx/backend/metal/jit/ternary.h"
#include "mlx/backend/metal/jit/unary.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
@ -46,38 +43,76 @@ MTL::ComputePipelineState* get_arange_kernel(
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
auto u_def = get_template_definition(
"v" + lib_name, "unary_v", get_type_string(out_type), op);
auto g_def = get_template_definition(
"g" + lib_name, "unary_g", get_type_string(out_type), op);
kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
<< fmt::format(
unary_kernels,
lib_name,
get_type_string(out.dtype()),
op_name(out));
<< u_def << g_def;
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
void add_binary_kernels(
const std::string lib_name,
Dtype in_type,
Dtype out_type,
const std::string op,
std::ostringstream& kernel_source) {
const std::map<std::string, std::string> kernel_types = {
{"ss", "binary_ss"},
{"vs", "binary_vs"},
{"sv", "binary_sv"},
{"vv", "binary_vv"},
{"g1", "binary_g_nd1"},
{"g2", "binary_g_nd2"},
{"g3", "binary_g_nd3"},
{"g4", "binary_g_nd"},
{"g5", "binary_g_nd"},
{"gn", "binary_g"},
};
for (auto [name, func] : kernel_types) {
std::string template_def;
if (name == "g4" || name == "g5") {
int dim = std::stoi(name.substr(1));
template_def = get_template_definition(
name + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op,
dim);
} else {
template_def = get_template_definition(
name + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op);
}
kernel_source << template_def;
}
}
MTL::ComputePipelineState* get_binary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
Dtype in_type,
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(2);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops() << metal::binary()
<< fmt::format(
binary_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name(out));
kernel_source << metal::utils() << metal::binary_ops() << metal::binary();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@ -86,20 +121,16 @@ MTL::ComputePipelineState* get_binary_kernel(
MTL::ComputePipelineState* get_binary_two_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
Dtype in_type,
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(2);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops()
<< metal::binary_two()
<< fmt::format(
binary_two_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name(out));
<< metal::binary_two();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@ -108,17 +139,34 @@ MTL::ComputePipelineState* get_binary_two_kernel(
MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
Dtype type,
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary()
<< fmt::format(
ternary_kernels,
lib_name,
get_type_string(out.dtype()),
op_name(out));
const std::map<std::string, std::string> kernel_types = {
{"v", "ternary_v"},
{"g", "ternary_g"},
{"g1", "ternary_g_nd1"},
{"g2", "ternary_g_nd2"},
{"g3", "ternary_g_nd3"},
{"g4", "ternary_g_nd"},
{"g5", "ternary_g_nd"},
};
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
for (auto [name, func] : kernel_types) {
std::string template_def;
if (name == "g4" || name == "g5") {
int dim = std::stoi(name.substr(1));
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op, dim);
} else {
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op);
}
kernel_source << template_def;
}
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
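
get_template_definition is declared in mlx/backend/common/compiled.h and is not shown in this diff. A minimal sketch of an implementation consistent with how it is called above (a host name, the kernel template to instantiate, type strings, and an optional integer dimension); illustrative only, not the actual helper:

#include <sstream>
#include <string>
#include <string_view>

template <typename... Args>
std::string get_template_definition(
    std::string_view name,  // host_name of the instantiation, e.g. "vv" + lib_name
    std::string_view func,  // kernel template to instantiate, e.g. "binary_vv"
    Args... args) {         // template arguments: type strings and optional dims
  std::ostringstream targs;
  bool first = true;
  auto append = [&](const auto& a) {
    if (!first) {
      targs << ", ";
    }
    first = false;
    targs << a;
  };
  (append(args), ...);
  std::ostringstream out;
  out << "template [[host_name(\"" << name << "\")]] [[kernel]] decltype("
      << func << "<" << targs.str() << ">) " << func << "<" << targs.str()
      << ">;\n";
  return out.str();
}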

View File

@ -15,24 +15,28 @@ MTL::ComputePipelineState* get_arange_kernel(
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out);
Dtype out_type,
const std::string op);
MTL::ComputePipelineState* get_binary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out);
Dtype in_type,
Dtype out_type,
const std::string op);
MTL::ComputePipelineState* get_binary_two_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out);
Dtype in_type,
Dtype out_type,
const std::string op);
MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out);
Dtype type,
const std::string op);
MTL::ComputePipelineState* get_copy_kernel(
metal::Device& d,

View File

@ -4,148 +4,91 @@
#include <metal_math>
// clang-format off
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h"
#define instantiate_binary(name, itype, otype, op, bopt) \
template \
[[host_name(name)]] [[kernel]] void binary_##bopt<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
uint index [[thread_position_in_grid]]);
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \
instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4) \
instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5)
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
template [[host_name("g" #dims name)]] [[kernel]] void \
binary_g_nd<itype, otype, op, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_integer(op) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
instantiate_binary_all(op, uint64, uint64_t, uint64_t) \
instantiate_binary_all(op, int8, int8_t, int8_t) \
instantiate_binary_all(op, int16, int16_t, int16_t) \
instantiate_binary_all(op, int32, int32_t, int32_t) \
instantiate_binary_all(op, int64, int64_t, int64_t)
#define instantiate_binary_g_nd(name, itype, otype, op) \
template [[host_name("g1" name)]] [[kernel]] void \
binary_g_nd1<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name("g2" name)]] [[kernel]] void \
binary_g_nd2<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name("g3" name)]] [[kernel]] void \
binary_g_nd3<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op, 4) \
instantiate_binary_g_dim(name, itype, otype, op, 5)
#define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \
instantiate_binary_all(op, float32, float, float) \
instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t)
#define instantiate_binary_g(name, itype, otype, op) \
template [[host_name("gn" name)]] [[kernel]] void binary_g<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_types(op) \
instantiate_binary_all(op, bool_, bool, bool) \
instantiate_binary_integer(op) \
instantiate_binary_all(op, complex64, complex64_t, complex64_t) \
instantiate_binary_float(op)
#define instantiate_binary_all(name, tname, itype, otype, op) \
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
instantiate_binary_g(#name #tname, itype, otype, op) \
instantiate_binary_g_nd(#name #tname, itype, otype, op)
#define instantiate_binary_types_bool(op) \
instantiate_binary_all(op, bool_, bool, bool) \
instantiate_binary_all(op, uint8, uint8_t, bool) \
instantiate_binary_all(op, uint16, uint16_t, bool) \
instantiate_binary_all(op, uint32, uint32_t, bool) \
instantiate_binary_all(op, uint64, uint64_t, bool) \
instantiate_binary_all(op, int8, int8_t, bool) \
instantiate_binary_all(op, int16, int16_t, bool) \
instantiate_binary_all(op, int32, int32_t, bool) \
instantiate_binary_all(op, int64, int64_t, bool) \
instantiate_binary_all(op, float16, half, bool) \
instantiate_binary_all(op, float32, float, bool) \
instantiate_binary_all(op, bfloat16, bfloat16_t, bool) \
instantiate_binary_all(op, complex64, complex64_t, bool)
#define instantiate_binary_integer(name, op) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
instantiate_binary_all(name, int64, int64_t, int64_t, op)
#define instantiate_binary_float(name, op) \
instantiate_binary_all(name, float16, half, half, op) \
instantiate_binary_all(name, float32, float, float, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
#define instantiate_binary_types(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_integer(name, op) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
instantiate_binary_float(name, op)
#define instantiate_binary_types_bool(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_all(name, uint8, uint8_t, bool, op) \
instantiate_binary_all(name, uint16, uint16_t, bool, op) \
instantiate_binary_all(name, uint32, uint32_t, bool, op) \
instantiate_binary_all(name, uint64, uint64_t, bool, op) \
instantiate_binary_all(name, int8, int8_t, bool, op) \
instantiate_binary_all(name, int16, int16_t, bool, op) \
instantiate_binary_all(name, int32, int32_t, bool, op) \
instantiate_binary_all(name, int64, int64_t, bool, op) \
instantiate_binary_all(name, float16, half, bool, op) \
instantiate_binary_all(name, float32, float, bool, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
instantiate_binary_all(name, complex64, complex64_t, bool, op)
instantiate_binary_types(add, Add)
instantiate_binary_types(div, Divide)
instantiate_binary_types_bool(eq, Equal)
instantiate_binary_types_bool(ge, Greater)
instantiate_binary_types_bool(geq, GreaterEqual)
instantiate_binary_types_bool(le, Less)
instantiate_binary_types_bool(leq, LessEqual)
instantiate_binary_types_bool(neq, NotEqual)
instantiate_binary_float(lae, LogAddExp)
instantiate_binary_types(max, Maximum)
instantiate_binary_types(min, Minimum)
instantiate_binary_types(mul, Multiply)
instantiate_binary_types(sub, Subtract)
instantiate_binary_types(pow, Power)
instantiate_binary_types(rem, Remainder)
instantiate_binary_float(arctan2, ArcTan2)
instantiate_binary_types(Add)
instantiate_binary_types(Divide)
instantiate_binary_types_bool(Equal)
instantiate_binary_types_bool(Greater)
instantiate_binary_types_bool(GreaterEqual)
instantiate_binary_types_bool(Less)
instantiate_binary_types_bool(LessEqual)
instantiate_binary_types_bool(NotEqual)
instantiate_binary_float(LogAddExp)
instantiate_binary_types(Maximum)
instantiate_binary_types(Minimum)
instantiate_binary_types(Multiply)
instantiate_binary_types(Subtract)
instantiate_binary_types(Power)
instantiate_binary_types(Remainder)
instantiate_binary_float(ArcTan2)
// NaNEqual only needed for floating point types with boolean output
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
instantiate_binary_all(naneq, float32, float, bool, NaNEqual)
instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
instantiate_binary_all(NaNEqual, float16, half, bool)
instantiate_binary_all(NaNEqual, float32, float, bool)
instantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool)
instantiate_binary_all(NaNEqual, complex64, complex64_t, bool)
instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
instantiate_binary_all(LogicalOr, bool_, bool, bool)
instantiate_binary_all(LogicalAnd, bool_, bool, bool)
// Bitwise ops only need integer types and bool (except for l/r shift)
instantiate_binary_integer(bitwise_and, BitwiseAnd)
instantiate_binary_all(bitwise_and, bool_, bool, bool, BitwiseAnd)
instantiate_binary_integer(bitwise_or, BitwiseOr)
instantiate_binary_all(bitwise_or, bool_, bool, bool, BitwiseOr)
instantiate_binary_integer(bitwise_xor, BitwiseXor)
instantiate_binary_all(bitwise_xor, bool_, bool, bool, BitwiseXor)
instantiate_binary_integer(left_shift, LeftShift)
instantiate_binary_integer(right_shift, RightShift) // clang-format on
instantiate_binary_integer(BitwiseAnd)
instantiate_binary_all(BitwiseAnd, bool_, bool, bool)
instantiate_binary_integer(BitwiseOr)
instantiate_binary_all(BitwiseOr, bool_, bool, bool)
instantiate_binary_integer(BitwiseXor)
instantiate_binary_all(BitwiseXor, bool_, bool, bool)
instantiate_binary_integer(LeftShift)
instantiate_binary_integer(RightShift) // clang-format on
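
The instantiate_kernel macro used above is not defined in this diff (it lives in the shared Metal kernel headers). A plausible definition consistent with the call sites, shown only as a sketch:

// Hypothetical sketch; the real macro is defined alongside the other kernel utilities.
#define instantiate_kernel(name, func, ...)                              \
  template [[host_name(                                                  \
      name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;

So, for example, instantiate_kernel("ssAddfloat32", binary_ss, float, float, Add) instantiates binary_ss<float, float, Add> under the host name "ssAddfloat32".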

View File

@ -7,99 +7,34 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h"
#define instantiate_binary(name, itype, otype, op, bopt) \
template [[host_name(name)]] [[kernel]] void \
binary_##bopt<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
uint index [[thread_position_in_grid]]);
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \
instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4) \
instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5)
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
template [[host_name("g" #dims name)]] [[kernel]] void \
binary_g_nd<itype, otype, op, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \
instantiate_binary_all(op, float32, float, float) \
instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t)
#define instantiate_binary_g_nd(name, itype, otype, op) \
template [[host_name("g1" name)]] [[kernel]] void \
binary_g_nd1<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name("g2" name)]] [[kernel]] void \
binary_g_nd2<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name("g3" name)]] [[kernel]] void \
binary_g_nd3<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op, 4) \
instantiate_binary_g_dim(name, itype, otype, op, 5)
#define instantiate_binary_types(op) \
instantiate_binary_all(op, bool_, bool, bool) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
instantiate_binary_all(op, uint64, uint64_t, uint64_t) \
instantiate_binary_all(op, int8, int8_t, int8_t) \
instantiate_binary_all(op, int16, int16_t, int16_t) \
instantiate_binary_all(op, int32, int32_t, int32_t) \
instantiate_binary_all(op, int64, int64_t, int64_t) \
instantiate_binary_all(op, complex64, complex64_t, complex64_t) \
instantiate_binary_float(op)
#define instantiate_binary_g(name, itype, otype, op) \
template [[host_name("gn" name)]] [[kernel]] void \
binary_g<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_all(name, tname, itype, otype, op) \
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
instantiate_binary_g(#name #tname, itype, otype, op) \
instantiate_binary_g_nd(#name #tname, itype, otype, op)
#define instantiate_binary_float(name, op) \
instantiate_binary_all(name, float16, half, half, op) \
instantiate_binary_all(name, float32, float, float, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
#define instantiate_binary_types(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
instantiate_binary_all(name, int64, int64_t, int64_t, op) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
instantiate_binary_float(name, op)
instantiate_binary_types(divmod, DivMod) // clang-format on
instantiate_binary_types(DivMod) // clang-format on

View File

@ -9,96 +9,28 @@
#include "mlx/backend/metal/kernels/ternary_ops.h"
#include "mlx/backend/metal/kernels/ternary.h"
#define instantiate_ternary_v(name, type, op) \
template [[host_name("v_" name)]] [[kernel]] void ternary_v<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
uint index [[thread_position_in_grid]]);
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
instantiate_kernel("g_" #op #tname, ternary_g, type, op) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op) \
instantiate_kernel("g4_" #op #tname, ternary_g_nd, type, op, 4) \
instantiate_kernel("g5_" #op #tname, ternary_g_nd, type, op, 5)
#define instantiate_ternary_g(name, type, op) \
template [[host_name("g_" name)]] [[kernel]] void ternary_g<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const size_t* c_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_ternary_types(op) \
instantiate_ternary_all(op, bool_, bool) \
instantiate_ternary_all(op, uint8, uint8_t) \
instantiate_ternary_all(op, uint16, uint16_t) \
instantiate_ternary_all(op, uint32, uint32_t) \
instantiate_ternary_all(op, uint64, uint64_t) \
instantiate_ternary_all(op, int8, int8_t) \
instantiate_ternary_all(op, int16, int16_t) \
instantiate_ternary_all(op, int32, int32_t) \
instantiate_ternary_all(op, int64, int64_t) \
instantiate_ternary_all(op, float16, half) \
instantiate_ternary_all(op, float32, float) \
instantiate_ternary_all(op, bfloat16, bfloat16_t) \
instantiate_ternary_all(op, complex64, complex64_t) // clang-format on
#define instantiate_ternary_g_dim(name, type, op, dims) \
template [[host_name("g" #dims "_" name )]] [[kernel]] void \
ternary_g_nd<type, op, dims>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
constant const size_t c_strides[dims], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_ternary_g_nd(name, type, op) \
template [[host_name("g1_" name)]] [[kernel]] void \
ternary_g_nd1<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const size_t& a_strides, \
constant const size_t& b_strides, \
constant const size_t& c_strides, \
uint index [[thread_position_in_grid]]); \
template [[host_name("g2_" name)]] [[kernel]] void \
ternary_g_nd2<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
constant const size_t c_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name("g3_" name)]] [[kernel]] void \
ternary_g_nd3<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
constant const size_t c_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_ternary_g_dim(name, type, op, 4) \
instantiate_ternary_g_dim(name, type, op, 5)
#define instantiate_ternary_all(name, tname, type, op) \
instantiate_ternary_v(#name #tname, type, op) \
instantiate_ternary_g(#name #tname, type, op) \
instantiate_ternary_g_nd(#name #tname, type, op)
#define instantiate_ternary_types(name, op) \
instantiate_ternary_all(name, bool_, bool, op) \
instantiate_ternary_all(name, uint8, uint8_t, op) \
instantiate_ternary_all(name, uint16, uint16_t, op) \
instantiate_ternary_all(name, uint32, uint32_t, op) \
instantiate_ternary_all(name, uint64, uint64_t, op) \
instantiate_ternary_all(name, int8, int8_t, op) \
instantiate_ternary_all(name, int16, int16_t, op) \
instantiate_ternary_all(name, int32, int32_t, op) \
instantiate_ternary_all(name, int64, int64_t, op) \
instantiate_ternary_all(name, float16, half, op) \
instantiate_ternary_all(name, float32, float, op) \
instantiate_ternary_all(name, bfloat16, bfloat16_t, op) \
instantiate_ternary_all(name, complex64, complex64_t, op) // clang-format on
instantiate_ternary_types(select, Select)
instantiate_ternary_types(Select)

View File

@ -5,83 +5,68 @@
#include "mlx/backend/metal/kernels/unary_ops.h"
#include "mlx/backend/metal/kernels/unary.h"
#define instantiate_unary_v(name, type, op) \
template [[host_name(name)]] [[kernel]] void unary_v<type, op>( \
device const type* in, \
device type* out, \
uint index [[thread_position_in_grid]]);
#define instantiate_unary_all(op, tname, type) \
instantiate_kernel("v" #op #tname, unary_v, type, op) \
instantiate_kernel("g" #op #tname, unary_g, type, op)
#define instantiate_unary_g(name, type, op) \
template [[host_name(name)]] [[kernel]] void unary_g<type, op>( \
device const type* in, \
device type* out, \
device const int* in_shape, \
device const size_t* in_strides, \
device const int& ndim, \
uint index [[thread_position_in_grid]]);
#define instantiate_unary_float(op) \
instantiate_unary_all(op, float16, half) \
instantiate_unary_all(op, float32, float) \
instantiate_unary_all(op, bfloat16, bfloat16_t)
#define instantiate_unary_all(name, tname, type, op) \
instantiate_unary_v("v" #name #tname, type, op) \
instantiate_unary_g("g" #name #tname, type, op)
#define instantiate_unary_types(op) \
instantiate_unary_all(op, bool_, bool) \
instantiate_unary_all(op, uint8, uint8_t) \
instantiate_unary_all(op, uint16, uint16_t) \
instantiate_unary_all(op, uint32, uint32_t) \
instantiate_unary_all(op, uint64, uint64_t) \
instantiate_unary_all(op, int8, int8_t) \
instantiate_unary_all(op, int16, int16_t) \
instantiate_unary_all(op, int32, int32_t) \
instantiate_unary_all(op, int64, int64_t) \
instantiate_unary_float(op)
#define instantiate_unary_float(name, op) \
instantiate_unary_all(name, float16, half, op) \
instantiate_unary_all(name, float32, float, op) \
instantiate_unary_all(name, bfloat16, bfloat16_t, op)
instantiate_unary_types(Abs)
instantiate_unary_float(ArcCos)
instantiate_unary_float(ArcCosh)
instantiate_unary_float(ArcSin)
instantiate_unary_float(ArcSinh)
instantiate_unary_float(ArcTan)
instantiate_unary_float(ArcTanh)
instantiate_unary_types(Ceil)
instantiate_unary_float(Cos)
instantiate_unary_float(Cosh)
instantiate_unary_float(Exp)
instantiate_unary_float(Expm1)
instantiate_unary_types(Floor)
instantiate_unary_float(Log)
instantiate_unary_float(Log2)
instantiate_unary_float(Log10)
instantiate_unary_float(Log1p)
instantiate_unary_types(Negative)
instantiate_unary_float(Sigmoid)
instantiate_unary_float(Erf)
instantiate_unary_float(ErfInv)
instantiate_unary_types(Sign)
instantiate_unary_float(Sin)
instantiate_unary_float(Sinh)
instantiate_unary_types(Square)
instantiate_unary_float(Sqrt)
instantiate_unary_float(Rsqrt)
instantiate_unary_float(Tan)
instantiate_unary_float(Tanh)
instantiate_unary_float(Round)
#define instantiate_unary_types(name, op) \
instantiate_unary_all(name, bool_, bool, op) \
instantiate_unary_all(name, uint8, uint8_t, op) \
instantiate_unary_all(name, uint16, uint16_t, op) \
instantiate_unary_all(name, uint32, uint32_t, op) \
instantiate_unary_all(name, uint64, uint64_t, op) \
instantiate_unary_all(name, int8, int8_t, op) \
instantiate_unary_all(name, int16, int16_t, op) \
instantiate_unary_all(name, int32, int32_t, op) \
instantiate_unary_all(name, int64, int64_t, op) \
instantiate_unary_float(name, op)
instantiate_unary_all(Abs, complex64, complex64_t)
instantiate_unary_all(Conjugate, complex64, complex64_t)
instantiate_unary_all(Cos, complex64, complex64_t)
instantiate_unary_all(Cosh, complex64, complex64_t)
instantiate_unary_all(Exp, complex64, complex64_t)
instantiate_unary_all(Negative, complex64, complex64_t)
instantiate_unary_all(Sin, complex64, complex64_t)
instantiate_unary_all(Sinh, complex64, complex64_t)
instantiate_unary_all(Tan, complex64, complex64_t)
instantiate_unary_all(Tanh, complex64, complex64_t)
instantiate_unary_all(Round, complex64, complex64_t)
instantiate_unary_types(abs, Abs)
instantiate_unary_float(arccos, ArcCos)
instantiate_unary_float(arccosh, ArcCosh)
instantiate_unary_float(arcsin, ArcSin)
instantiate_unary_float(arcsinh, ArcSinh)
instantiate_unary_float(arctan, ArcTan)
instantiate_unary_float(arctanh, ArcTanh)
instantiate_unary_types(ceil, Ceil)
instantiate_unary_float(cos, Cos)
instantiate_unary_float(cosh, Cosh)
instantiate_unary_float(exp, Exp)
instantiate_unary_float(expm1, Expm1)
instantiate_unary_types(floor, Floor)
instantiate_unary_float(log, Log)
instantiate_unary_float(log2, Log2)
instantiate_unary_float(log10, Log10)
instantiate_unary_float(log1p, Log1p)
instantiate_unary_types(neg, Negative)
instantiate_unary_float(sigmoid, Sigmoid)
instantiate_unary_float(erf, Erf)
instantiate_unary_float(erfinv, ErfInv)
instantiate_unary_types(sign, Sign)
instantiate_unary_float(sin, Sin)
instantiate_unary_float(sinh, Sinh)
instantiate_unary_types(square, Square)
instantiate_unary_float(sqrt, Sqrt)
instantiate_unary_float(rsqrt, Rsqrt)
instantiate_unary_float(tan, Tan)
instantiate_unary_float(tanh, Tanh)
instantiate_unary_float(round, Round)
instantiate_unary_all(abs, complex64, complex64_t, Abs)
instantiate_unary_all(conj, complex64, complex64_t, Conjugate)
instantiate_unary_all(cos, complex64, complex64_t, Cos)
instantiate_unary_all(cosh, complex64, complex64_t, Cosh)
instantiate_unary_all(exp, complex64, complex64_t, Exp)
instantiate_unary_all(neg, complex64, complex64_t, Negative)
instantiate_unary_all(sin, complex64, complex64_t, Sin)
instantiate_unary_all(sinh, complex64, complex64_t, Sinh)
instantiate_unary_all(tan, complex64, complex64_t, Tan)
instantiate_unary_all(tanh, complex64, complex64_t, Tanh)
instantiate_unary_all(round, complex64, complex64_t, Round)
instantiate_unary_all(lnot, bool_, bool, LogicalNot) // clang-format on
instantiate_unary_all(LogicalNot, bool_, bool) // clang-format on

View File

@ -15,30 +15,34 @@ MTL::ComputePipelineState* get_arange_kernel(
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&) {
Dtype,
const std::string) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_binary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&,
const array&) {
Dtype,
Dtype,
const std::string) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_binary_two_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&,
const array&) {
Dtype,
Dtype,
const std::string) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&) {
Dtype,
const std::string) {
return d.get_kernel(kernel_name);
}

View File

@ -49,7 +49,7 @@ void ternary_op_gpu_inplace(
auto& d = metal::device(s.device);
auto kernel = get_ternary_kernel(d, kernel_name, out);
auto kernel = get_ternary_kernel(d, kernel_name, out.dtype(), op);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
@ -122,7 +122,7 @@ void ternary_op_gpu(
}
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
ternary_op_gpu(inputs, out, "select");
ternary_op_gpu(inputs, out, get_primitive_string(this));
}
} // namespace mlx::core

View File

@ -5,6 +5,11 @@
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#define UNARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
unary_op_gpu(inputs, out, get_primitive_string(this)); \
}
namespace mlx::core {
void unary_op_gpu_inplace(
@ -21,7 +26,7 @@ void unary_op_gpu_inplace(
auto& d = metal::device(s.device);
std::string kernel_name = (contig ? "v" : "g") + op + type_to_name(out);
auto kernel = get_unary_kernel(d, kernel_name, out);
auto kernel = get_unary_kernel(d, kernel_name, out.dtype(), op);
size_t nthreads = contig ? in.data_size() : in.size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
@ -77,148 +82,57 @@ void unary_op_gpu(
unary_op_gpu(inputs, out, op, s);
}
void Abs::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "abs");
}
void ArcCos::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "arccos");
}
void ArcCosh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "arccosh");
}
void ArcSin::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "arcsin");
}
void ArcSinh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "arcsinh");
}
void ArcTan::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "arctan");
}
void ArcTanh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "arctanh");
}
void Conjugate::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == complex64) {
unary_op_gpu(inputs, out, "conj");
} else {
throw std::invalid_argument(
"[conjugate] conjugate must be called on complex input.");
}
}
void Cos::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "cos");
}
void Cosh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "cosh");
}
void Erf::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "erf");
}
void ErfInv::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "erfinv");
}
void Exp::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "exp");
}
void Expm1::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "expm1");
}
UNARY_GPU(Abs)
UNARY_GPU(ArcCos)
UNARY_GPU(ArcCosh)
UNARY_GPU(ArcSin)
UNARY_GPU(ArcSinh)
UNARY_GPU(ArcTan)
UNARY_GPU(ArcTanh)
UNARY_GPU(Conjugate)
UNARY_GPU(Cos)
UNARY_GPU(Cosh)
UNARY_GPU(Erf)
UNARY_GPU(ErfInv)
UNARY_GPU(Exp)
UNARY_GPU(Expm1)
UNARY_GPU(Log1p)
UNARY_GPU(LogicalNot)
UNARY_GPU(Floor)
UNARY_GPU(Ceil)
UNARY_GPU(Negative)
UNARY_GPU(Sigmoid)
UNARY_GPU(Sign)
UNARY_GPU(Sin)
UNARY_GPU(Sinh)
UNARY_GPU(Square)
UNARY_GPU(Sqrt)
UNARY_GPU(Tan)
UNARY_GPU(Tanh)
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (base_) {
case Base::e:
unary_op_gpu(inputs, out, "log");
unary_op_gpu(inputs, out, get_primitive_string(this));
break;
case Base::two:
unary_op_gpu(inputs, out, "log2");
unary_op_gpu(inputs, out, get_primitive_string(this));
break;
case Base::ten:
unary_op_gpu(inputs, out, "log10");
unary_op_gpu(inputs, out, get_primitive_string(this));
break;
}
}
void Log1p::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "log1p");
}
void LogicalNot::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "lnot");
}
void Floor::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "floor");
}
void Ceil::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "ceil");
}
void Negative::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "neg");
}
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_op_gpu(inputs, out, "round");
unary_op_gpu(inputs, out, get_primitive_string(this));
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Sigmoid::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "sigmoid");
}
void Sign::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "sign");
}
void Sin::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "sin");
}
void Sinh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "sinh");
}
void Square::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "square");
}
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
if (recip_) {
unary_op_gpu(inputs, out, "rsqrt");
} else {
unary_op_gpu(inputs, out, "sqrt");
}
}
void Tan::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "tan");
}
void Tanh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, "tanh");
}
} // namespace mlx::core

View File

@ -141,6 +141,12 @@ int next_power_of_2(int n) {
return pow(2, std::ceil(std::log2(n)));
}
std::string get_primitive_string(Primitive* primitive) {
std::ostringstream op_t;
primitive->print(op_t);
return op_t.str();
}
} // namespace
} // namespace mlx::core
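
Putting the pieces together, the helper's printed name is what the eval_gpu macros feed into the kernel-name construction seen in unary.cpp above; a small illustration, assuming type_to_name returns suffixes like "float32":

std::string op = get_primitive_string(this);  // e.g. "Abs" inside Abs::eval_gpu
std::string kernel_name = (contig ? "v" : "g") + op + type_to_name(out);
// e.g. "vAbsfloat32" for a contiguous float32 output, matching the host names
// produced by get_unary_kernel and by instantiate_unary_all in unary.metal.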