Add Primitive::name and remove Primitive::print (#2365)

This commit is contained in:
Cheng
2025-07-15 06:06:35 +09:00
committed by GitHub
parent 5201df5030
commit d34f887abc
32 changed files with 307 additions and 340 deletions

View File

@@ -7,20 +7,20 @@
#define BINARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
binary_op_gpu(inputs, out, get_primitive_string(this)); \
binary_op_gpu(inputs, out, name()); \
}
#define BINARY_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
binary_op_gpu(inputs, outputs, get_primitive_string(this)); \
binary_op_gpu(inputs, outputs, name()); \
}
namespace mlx::core {
std::string get_kernel_name(
BinaryOpType bopt,
const std::string& op,
const char* op,
const array& a,
bool large,
int ndim,
@@ -65,7 +65,7 @@ std::string get_kernel_name(
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string& op,
const char* op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
@@ -165,7 +165,7 @@ void binary_op_gpu_inplace(
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string& op,
const char* op,
const Stream& s) {
assert(inputs.size() == 2);
auto& a = inputs[0];
@@ -179,7 +179,7 @@ void binary_op_gpu(
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string& op) {
const char* op) {
auto& s = outputs[0].primitive().stream();
binary_op_gpu(inputs, outputs, op, s);
}
@@ -187,7 +187,7 @@ void binary_op_gpu(
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const char* op,
const Stream& s) {
std::vector<array> outputs = {out};
binary_op_gpu_inplace(inputs, outputs, op, s);
@@ -196,7 +196,7 @@ void binary_op_gpu_inplace(
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const char* op,
const Stream& s) {
assert(inputs.size() == 2);
auto& a = inputs[0];
@@ -209,7 +209,7 @@ void binary_op_gpu(
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string& op) {
const char* op) {
auto& s = out.primitive().stream();
binary_op_gpu(inputs, out, op, s);
}
@@ -237,19 +237,19 @@ BINARY_GPU(Subtract)
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu(inputs, out, get_primitive_string(this));
binary_op_gpu(inputs, out, name());
break;
case BitwiseBinary::Or:
binary_op_gpu(inputs, out, get_primitive_string(this));
binary_op_gpu(inputs, out, name());
break;
case BitwiseBinary::Xor:
binary_op_gpu(inputs, out, get_primitive_string(this));
binary_op_gpu(inputs, out, name());
break;
case BitwiseBinary::LeftShift:
binary_op_gpu(inputs, out, get_primitive_string(this));
binary_op_gpu(inputs, out, name());
break;
case BitwiseBinary::RightShift:
binary_op_gpu(inputs, out, get_primitive_string(this));
binary_op_gpu(inputs, out, name());
break;
}
}

View File

@@ -9,25 +9,25 @@ namespace mlx::core {
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string& op,
const char* op,
const Stream& s);
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const char* op,
const Stream& s);
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string& op,
const char* op,
const Stream& s);
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const char* op,
const Stream& s);
} // namespace mlx::core

View File

@@ -212,9 +212,7 @@ inline void build_kernel(
get_type_string(x.dtype()),
namer.get_name(x.inputs()[0]));
} else {
std::ostringstream ss;
x.primitive().print(ss);
os += ss.str();
os += x.primitive().name();
os += "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));

View File

@@ -8,12 +8,6 @@ using namespace fmt::literals;
namespace mlx::core {
std::string op_name(const array& arr) {
std::ostringstream op_t;
arr.primitive().print(op_t);
return op_t.str();
}
MTL::ComputePipelineState* get_arange_kernel(
metal::Device& d,
const std::string& kernel_name,
@@ -33,7 +27,7 @@ MTL::ComputePipelineState* get_unary_kernel(
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op) {
const char* op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
auto in_t = get_type_string(in_type);
@@ -58,10 +52,10 @@ MTL::ComputePipelineState* get_unary_kernel(
}
void append_binary_kernels(
const std::string lib_name,
const std::string& lib_name,
Dtype in_type,
Dtype out_type,
const std::string op,
const char* op,
std::string& kernel_source) {
const std::array<std::pair<std::string, std::string>, 7> kernel_types = {{
{"ss", "binary_ss"},
@@ -112,7 +106,7 @@ MTL::ComputePipelineState* get_binary_kernel(
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op) {
const char* op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source;
@@ -129,7 +123,7 @@ MTL::ComputePipelineState* get_binary_two_kernel(
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op) {
const char* op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
@@ -144,7 +138,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype type,
const std::string op) {
const char* op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
auto t_str = get_type_string(type);

View File

@@ -19,27 +19,27 @@ MTL::ComputePipelineState* get_unary_kernel(
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op);
const char* op);
MTL::ComputePipelineState* get_binary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op);
const char* op);
MTL::ComputePipelineState* get_binary_two_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op);
const char* op);
MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype type,
const std::string op);
const char* op);
MTL::ComputePipelineState* get_copy_kernel(
metal::Device& d,
@@ -257,8 +257,10 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
// Create a GPU kernel template definition for JIT compilation
template <typename... Args>
std::string
get_template_definition(std::string name, std::string func, Args... args) {
std::string get_template_definition(
std::string_view name,
std::string_view func,
Args... args) {
std::ostringstream s;
s << func << "<";
bool first = true;

View File

@@ -18,7 +18,7 @@ MTL::ComputePipelineState* get_unary_kernel(
const std::string& kernel_name,
Dtype,
Dtype,
const std::string) {
const char*) {
return d.get_kernel(kernel_name);
}
@@ -27,7 +27,7 @@ MTL::ComputePipelineState* get_binary_kernel(
const std::string& kernel_name,
Dtype,
Dtype,
const std::string) {
const char*) {
return d.get_kernel(kernel_name);
}
@@ -36,7 +36,7 @@ MTL::ComputePipelineState* get_binary_two_kernel(
const std::string& kernel_name,
Dtype,
Dtype,
const std::string) {
const char*) {
return d.get_kernel(kernel_name);
}
@@ -44,7 +44,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype,
const std::string) {
const char*) {
return d.get_kernel(kernel_name);
}

View File

@@ -11,7 +11,7 @@ namespace mlx::core {
void ternary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string op,
const char* op,
const Stream& s) {
assert(inputs.size() == 3);
auto& a = inputs[0];
@@ -128,7 +128,7 @@ void ternary_op_gpu_inplace(
void ternary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op,
const char* op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
@@ -141,13 +141,13 @@ void ternary_op_gpu(
void ternary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op) {
const char* op) {
auto& s = out.primitive().stream();
ternary_op_gpu(inputs, out, op, s);
}
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
ternary_op_gpu(inputs, out, get_primitive_string(this));
ternary_op_gpu(inputs, out, name());
}
} // namespace mlx::core

View File

@@ -9,13 +9,13 @@ namespace mlx::core {
void ternary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op,
const char* op,
const Stream& s);
void ternary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string op,
const char* op,
const Stream& s);
} // namespace mlx::core

View File

@@ -8,7 +8,7 @@
#define UNARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
unary_op_gpu(inputs, out, get_primitive_string(this)); \
unary_op_gpu(inputs, out, name()); \
}
namespace mlx::core {
@@ -16,7 +16,7 @@ namespace mlx::core {
void unary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string op,
const char* op,
const Stream& s) {
auto& in = inputs[0];
bool contig = in.flags().contiguous;
@@ -98,7 +98,7 @@ void unary_op_gpu_inplace(
void unary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op,
const char* op,
const Stream& s) {
set_unary_output_data(inputs[0], out);
unary_op_gpu_inplace(inputs, out, op, s);
@@ -107,7 +107,7 @@ void unary_op_gpu(
void unary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op) {
const char* op) {
auto& s = out.primitive().stream();
unary_op_gpu(inputs, out, op, s);
}
@@ -146,13 +146,13 @@ UNARY_GPU(Tanh)
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (base_) {
case Base::e:
unary_op_gpu(inputs, out, get_primitive_string(this));
unary_op_gpu(inputs, out, name());
break;
case Base::two:
unary_op_gpu(inputs, out, get_primitive_string(this));
unary_op_gpu(inputs, out, name());
break;
case Base::ten:
unary_op_gpu(inputs, out, get_primitive_string(this));
unary_op_gpu(inputs, out, name());
break;
}
}
@@ -161,7 +161,7 @@ void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_op_gpu(inputs, out, get_primitive_string(this));
unary_op_gpu(inputs, out, name());
} else {
// No-op integer types
out.copy_shared_buffer(in);

View File

@@ -9,13 +9,13 @@ namespace mlx::core {
void unary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op,
const char* op,
const Stream& s);
void unary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string op,
const char* op,
const Stream& s);
} // namespace mlx::core

View File

@@ -40,7 +40,7 @@ inline void debug_set_primitive_buffer_label(
if (auto cbuf_label = command_buffer->label(); cbuf_label) {
label << cbuf_label->utf8String();
}
primitive.print(label);
label << primitive.name();
command_buffer->setLabel(make_string(label));
#endif
}