mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
@@ -40,18 +40,21 @@ MTL::ComputePipelineState* get_arange_kernel(
|
||||
MTL::ComputePipelineState* get_unary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
auto in_t = get_type_string(in_type);
|
||||
auto out_t = get_type_string(out_type);
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
|
||||
kernel_source << get_template_definition(
|
||||
"v_" + lib_name, "unary_v", get_type_string(out_type), op);
|
||||
"v_" + lib_name, "unary_v", in_t, out_t, op);
|
||||
kernel_source << get_template_definition(
|
||||
"v2_" + lib_name, "unary_v2", get_type_string(out_type), op);
|
||||
"v2_" + lib_name, "unary_v2", in_t, out_t, op);
|
||||
kernel_source << get_template_definition(
|
||||
"gn4_" + lib_name, "unary_g", get_type_string(out_type), op, 4);
|
||||
"gn4_" + lib_name, "unary_g", in_t, out_t, op, 4);
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
|
@@ -15,6 +15,7 @@ MTL::ComputePipelineState* get_arange_kernel(
|
||||
MTL::ComputePipelineState* get_unary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op);
|
||||
|
||||
|
@@ -1,27 +1,27 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
template <typename T, typename Op>
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void unary_v(
|
||||
device const T* in,
|
||||
device T* out,
|
||||
device U* out,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
out[index] = Op()(in[index]);
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void unary_v2(
|
||||
device const T* in,
|
||||
device T* out,
|
||||
device U* out,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
out[offset] = Op()(in[offset]);
|
||||
}
|
||||
|
||||
template <typename T, typename Op, int N = 1>
|
||||
template <typename T, typename U, typename Op, int N = 1>
|
||||
[[kernel]] void unary_g(
|
||||
device const T* in,
|
||||
device T* out,
|
||||
device U* out,
|
||||
constant const int* in_shape,
|
||||
constant const size_t* in_strides,
|
||||
device const int& ndim,
|
||||
|
@@ -5,26 +5,30 @@
|
||||
#include "mlx/backend/metal/kernels/unary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/unary.h"
|
||||
|
||||
#define instantiate_unary_all(op, tname, type) \
|
||||
instantiate_kernel("v_" #op #tname, unary_v, type, op) \
|
||||
instantiate_kernel("v2_" #op #tname, unary_v2, type, op) \
|
||||
instantiate_kernel("gn4_" #op #tname, unary_g, type, op, 4)
|
||||
#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
|
||||
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
|
||||
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
|
||||
instantiate_kernel("gn4_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
|
||||
|
||||
|
||||
#define instantiate_unary_all_same(op, tname, type) \
|
||||
instantiate_unary_all(op, tname, tname, type, type)
|
||||
|
||||
#define instantiate_unary_float(op) \
|
||||
instantiate_unary_all(op, float16, half) \
|
||||
instantiate_unary_all(op, float32, float) \
|
||||
instantiate_unary_all(op, bfloat16, bfloat16_t)
|
||||
instantiate_unary_all_same(op, float16, half) \
|
||||
instantiate_unary_all_same(op, float32, float) \
|
||||
instantiate_unary_all_same(op, bfloat16, bfloat16_t)
|
||||
|
||||
#define instantiate_unary_types(op) \
|
||||
instantiate_unary_all(op, bool_, bool) \
|
||||
instantiate_unary_all(op, uint8, uint8_t) \
|
||||
instantiate_unary_all(op, uint16, uint16_t) \
|
||||
instantiate_unary_all(op, uint32, uint32_t) \
|
||||
instantiate_unary_all(op, uint64, uint64_t) \
|
||||
instantiate_unary_all(op, int8, int8_t) \
|
||||
instantiate_unary_all(op, int16, int16_t) \
|
||||
instantiate_unary_all(op, int32, int32_t) \
|
||||
instantiate_unary_all(op, int64, int64_t) \
|
||||
instantiate_unary_all_same(op, bool_, bool) \
|
||||
instantiate_unary_all_same(op, uint8, uint8_t) \
|
||||
instantiate_unary_all_same(op, uint16, uint16_t) \
|
||||
instantiate_unary_all_same(op, uint32, uint32_t) \
|
||||
instantiate_unary_all_same(op, uint64, uint64_t) \
|
||||
instantiate_unary_all_same(op, int8, int8_t) \
|
||||
instantiate_unary_all_same(op, int16, int16_t) \
|
||||
instantiate_unary_all_same(op, int32, int32_t) \
|
||||
instantiate_unary_all_same(op, int64, int64_t) \
|
||||
instantiate_unary_float(op)
|
||||
|
||||
instantiate_unary_types(Abs)
|
||||
@@ -58,17 +62,19 @@ instantiate_unary_float(Tan)
|
||||
instantiate_unary_float(Tanh)
|
||||
instantiate_unary_float(Round)
|
||||
|
||||
instantiate_unary_all(Abs, complex64, complex64_t)
|
||||
instantiate_unary_all(Conjugate, complex64, complex64_t)
|
||||
instantiate_unary_all(Cos, complex64, complex64_t)
|
||||
instantiate_unary_all(Cosh, complex64, complex64_t)
|
||||
instantiate_unary_all(Exp, complex64, complex64_t)
|
||||
instantiate_unary_all(Negative, complex64, complex64_t)
|
||||
instantiate_unary_all(Sign, complex64, complex64_t)
|
||||
instantiate_unary_all(Sin, complex64, complex64_t)
|
||||
instantiate_unary_all(Sinh, complex64, complex64_t)
|
||||
instantiate_unary_all(Tan, complex64, complex64_t)
|
||||
instantiate_unary_all(Tanh, complex64, complex64_t)
|
||||
instantiate_unary_all(Round, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Abs, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Conjugate, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Cos, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Cosh, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Exp, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Negative, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Sign, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Sin, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Sinh, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Tan, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Tanh, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Round, complex64, complex64_t)
|
||||
instantiate_unary_all(Real, complex64, float32, complex64_t, float)
|
||||
instantiate_unary_all(Imag, complex64, float32, complex64_t, float)
|
||||
|
||||
instantiate_unary_all(LogicalNot, bool_, bool) // clang-format on
|
||||
instantiate_unary_all_same(LogicalNot, bool_, bool) // clang-format on
|
||||
|
@@ -238,6 +238,13 @@ struct Floor {
|
||||
};
|
||||
};
|
||||
|
||||
struct Imag {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return x.imag;
|
||||
};
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
@@ -280,6 +287,13 @@ struct Negative {
|
||||
};
|
||||
};
|
||||
|
||||
struct Real {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return x.real;
|
||||
};
|
||||
};
|
||||
|
||||
struct Round {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
|
@@ -16,6 +16,7 @@ MTL::ComputePipelineState* get_unary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
Dtype,
|
||||
Dtype,
|
||||
const std::string) {
|
||||
return d.get_kernel(kernel_name);
|
||||
}
|
||||
|
@@ -44,8 +44,8 @@ void unary_op_gpu_inplace(
|
||||
} else {
|
||||
kernel_name = (work_per_thread == 4 ? "gn4" : "g");
|
||||
}
|
||||
kernel_name += "_" + op + type_to_name(out);
|
||||
auto kernel = get_unary_kernel(d, kernel_name, out.dtype(), op);
|
||||
kernel_name += "_" + op + type_to_name(in) + type_to_name(out);
|
||||
auto kernel = get_unary_kernel(d, kernel_name, in.dtype(), out.dtype(), op);
|
||||
|
||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(in.shape(), in.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
@@ -124,11 +124,13 @@ UNARY_GPU(Erf)
|
||||
UNARY_GPU(ErfInv)
|
||||
UNARY_GPU(Exp)
|
||||
UNARY_GPU(Expm1)
|
||||
UNARY_GPU(Imag)
|
||||
UNARY_GPU(Log1p)
|
||||
UNARY_GPU(LogicalNot)
|
||||
UNARY_GPU(Floor)
|
||||
UNARY_GPU(Ceil)
|
||||
UNARY_GPU(Negative)
|
||||
UNARY_GPU(Real)
|
||||
UNARY_GPU(Sigmoid)
|
||||
UNARY_GPU(Sign)
|
||||
UNARY_GPU(Sin)
|
||||
|
Reference in New Issue
Block a user