mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-21 18:28:11 +08:00
Add Primitive::name and remove Primitive::print (#2365)
This commit is contained in:
@@ -7,20 +7,20 @@
|
||||
|
||||
#define BINARY_GPU(func) \
|
||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this)); \
|
||||
binary_op_gpu(inputs, out, name()); \
|
||||
}
|
||||
|
||||
#define BINARY_GPU_MULTI(func) \
|
||||
void func::eval_gpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
binary_op_gpu(inputs, outputs, get_primitive_string(this)); \
|
||||
binary_op_gpu(inputs, outputs, name()); \
|
||||
}
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::string get_kernel_name(
|
||||
BinaryOpType bopt,
|
||||
const std::string& op,
|
||||
const char* op,
|
||||
const array& a,
|
||||
bool large,
|
||||
int ndim,
|
||||
@@ -65,7 +65,7 @@ std::string get_kernel_name(
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string& op,
|
||||
const char* op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
@@ -165,7 +165,7 @@ void binary_op_gpu_inplace(
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string& op,
|
||||
const char* op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
@@ -179,7 +179,7 @@ void binary_op_gpu(
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string& op) {
|
||||
const char* op) {
|
||||
auto& s = outputs[0].primitive().stream();
|
||||
binary_op_gpu(inputs, outputs, op, s);
|
||||
}
|
||||
@@ -187,7 +187,7 @@ void binary_op_gpu(
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op,
|
||||
const char* op,
|
||||
const Stream& s) {
|
||||
std::vector<array> outputs = {out};
|
||||
binary_op_gpu_inplace(inputs, outputs, op, s);
|
||||
@@ -196,7 +196,7 @@ void binary_op_gpu_inplace(
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op,
|
||||
const char* op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
@@ -209,7 +209,7 @@ void binary_op_gpu(
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op) {
|
||||
const char* op) {
|
||||
auto& s = out.primitive().stream();
|
||||
binary_op_gpu(inputs, out, op, s);
|
||||
}
|
||||
@@ -237,19 +237,19 @@ BINARY_GPU(Subtract)
|
||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
binary_op_gpu(inputs, out, name());
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
binary_op_gpu(inputs, out, name());
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
binary_op_gpu(inputs, out, name());
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
binary_op_gpu(inputs, out, name());
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
binary_op_gpu(inputs, out, name());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@@ -9,25 +9,25 @@ namespace mlx::core {
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string& op,
|
||||
const char* op,
|
||||
const Stream& s);
|
||||
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op,
|
||||
const char* op,
|
||||
const Stream& s);
|
||||
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string& op,
|
||||
const char* op,
|
||||
const Stream& s);
|
||||
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op,
|
||||
const char* op,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -212,9 +212,7 @@ inline void build_kernel(
|
||||
get_type_string(x.dtype()),
|
||||
namer.get_name(x.inputs()[0]));
|
||||
} else {
|
||||
std::ostringstream ss;
|
||||
x.primitive().print(ss);
|
||||
os += ss.str();
|
||||
os += x.primitive().name();
|
||||
os += "()(";
|
||||
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
||||
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
|
||||
|
@@ -8,12 +8,6 @@ using namespace fmt::literals;
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::string op_name(const array& arr) {
|
||||
std::ostringstream op_t;
|
||||
arr.primitive().print(op_t);
|
||||
return op_t.str();
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_arange_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
@@ -33,7 +27,7 @@ MTL::ComputePipelineState* get_unary_kernel(
|
||||
const std::string& kernel_name,
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op) {
|
||||
const char* op) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
auto in_t = get_type_string(in_type);
|
||||
@@ -58,10 +52,10 @@ MTL::ComputePipelineState* get_unary_kernel(
|
||||
}
|
||||
|
||||
void append_binary_kernels(
|
||||
const std::string lib_name,
|
||||
const std::string& lib_name,
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op,
|
||||
const char* op,
|
||||
std::string& kernel_source) {
|
||||
const std::array<std::pair<std::string, std::string>, 7> kernel_types = {{
|
||||
{"ss", "binary_ss"},
|
||||
@@ -112,7 +106,7 @@ MTL::ComputePipelineState* get_binary_kernel(
|
||||
const std::string& kernel_name,
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op) {
|
||||
const char* op) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source;
|
||||
@@ -129,7 +123,7 @@ MTL::ComputePipelineState* get_binary_two_kernel(
|
||||
const std::string& kernel_name,
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op) {
|
||||
const char* op) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source = metal::utils();
|
||||
@@ -144,7 +138,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
Dtype type,
|
||||
const std::string op) {
|
||||
const char* op) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
auto t_str = get_type_string(type);
|
||||
|
@@ -19,27 +19,27 @@ MTL::ComputePipelineState* get_unary_kernel(
|
||||
const std::string& kernel_name,
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op);
|
||||
const char* op);
|
||||
|
||||
MTL::ComputePipelineState* get_binary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op);
|
||||
const char* op);
|
||||
|
||||
MTL::ComputePipelineState* get_binary_two_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op);
|
||||
const char* op);
|
||||
|
||||
MTL::ComputePipelineState* get_ternary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
Dtype type,
|
||||
const std::string op);
|
||||
const char* op);
|
||||
|
||||
MTL::ComputePipelineState* get_copy_kernel(
|
||||
metal::Device& d,
|
||||
@@ -257,8 +257,10 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||
|
||||
// Create a GPU kernel template definition for JIT compilation
|
||||
template <typename... Args>
|
||||
std::string
|
||||
get_template_definition(std::string name, std::string func, Args... args) {
|
||||
std::string get_template_definition(
|
||||
std::string_view name,
|
||||
std::string_view func,
|
||||
Args... args) {
|
||||
std::ostringstream s;
|
||||
s << func << "<";
|
||||
bool first = true;
|
||||
|
@@ -18,7 +18,7 @@ MTL::ComputePipelineState* get_unary_kernel(
|
||||
const std::string& kernel_name,
|
||||
Dtype,
|
||||
Dtype,
|
||||
const std::string) {
|
||||
const char*) {
|
||||
return d.get_kernel(kernel_name);
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ MTL::ComputePipelineState* get_binary_kernel(
|
||||
const std::string& kernel_name,
|
||||
Dtype,
|
||||
Dtype,
|
||||
const std::string) {
|
||||
const char*) {
|
||||
return d.get_kernel(kernel_name);
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ MTL::ComputePipelineState* get_binary_two_kernel(
|
||||
const std::string& kernel_name,
|
||||
Dtype,
|
||||
Dtype,
|
||||
const std::string) {
|
||||
const char*) {
|
||||
return d.get_kernel(kernel_name);
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
Dtype,
|
||||
const std::string) {
|
||||
const char*) {
|
||||
return d.get_kernel(kernel_name);
|
||||
}
|
||||
|
||||
|
@@ -11,7 +11,7 @@ namespace mlx::core {
|
||||
void ternary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op,
|
||||
const char* op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() == 3);
|
||||
auto& a = inputs[0];
|
||||
@@ -128,7 +128,7 @@ void ternary_op_gpu_inplace(
|
||||
void ternary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op,
|
||||
const char* op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
@@ -141,13 +141,13 @@ void ternary_op_gpu(
|
||||
void ternary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op) {
|
||||
const char* op) {
|
||||
auto& s = out.primitive().stream();
|
||||
ternary_op_gpu(inputs, out, op, s);
|
||||
}
|
||||
|
||||
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
ternary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
ternary_op_gpu(inputs, out, name());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -9,13 +9,13 @@ namespace mlx::core {
|
||||
void ternary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op,
|
||||
const char* op,
|
||||
const Stream& s);
|
||||
|
||||
void ternary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op,
|
||||
const char* op,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -8,7 +8,7 @@
|
||||
|
||||
#define UNARY_GPU(func) \
|
||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||
unary_op_gpu(inputs, out, get_primitive_string(this)); \
|
||||
unary_op_gpu(inputs, out, name()); \
|
||||
}
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -16,7 +16,7 @@ namespace mlx::core {
|
||||
void unary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op,
|
||||
const char* op,
|
||||
const Stream& s) {
|
||||
auto& in = inputs[0];
|
||||
bool contig = in.flags().contiguous;
|
||||
@@ -98,7 +98,7 @@ void unary_op_gpu_inplace(
|
||||
void unary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op,
|
||||
const char* op,
|
||||
const Stream& s) {
|
||||
set_unary_output_data(inputs[0], out);
|
||||
unary_op_gpu_inplace(inputs, out, op, s);
|
||||
@@ -107,7 +107,7 @@ void unary_op_gpu(
|
||||
void unary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op) {
|
||||
const char* op) {
|
||||
auto& s = out.primitive().stream();
|
||||
unary_op_gpu(inputs, out, op, s);
|
||||
}
|
||||
@@ -146,13 +146,13 @@ UNARY_GPU(Tanh)
|
||||
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
unary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
unary_op_gpu(inputs, out, name());
|
||||
break;
|
||||
case Base::two:
|
||||
unary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
unary_op_gpu(inputs, out, name());
|
||||
break;
|
||||
case Base::ten:
|
||||
unary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
unary_op_gpu(inputs, out, name());
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -161,7 +161,7 @@ void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
unary_op_gpu(inputs, out, name());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
|
@@ -9,13 +9,13 @@ namespace mlx::core {
|
||||
void unary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op,
|
||||
const char* op,
|
||||
const Stream& s);
|
||||
|
||||
void unary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op,
|
||||
const char* op,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -40,7 +40,7 @@ inline void debug_set_primitive_buffer_label(
|
||||
if (auto cbuf_label = command_buffer->label(); cbuf_label) {
|
||||
label << cbuf_label->utf8String();
|
||||
}
|
||||
primitive.print(label);
|
||||
label << primitive.name();
|
||||
command_buffer->setLabel(make_string(label));
|
||||
#endif
|
||||
}
|
||||
|
Reference in New Issue
Block a user