Add Primitive::name and remove Primitive::print (#2365)

This commit is contained in:
Cheng 2025-07-15 06:06:35 +09:00 committed by GitHub
parent 5201df5030
commit d34f887abc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 307 additions and 340 deletions

View File

@ -138,13 +138,13 @@ more concrete:
* representing the vectorized computation and the axis which * representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension. * corresponds to the output vectorized dimension.
*/ */
virtual std::pair<std::vector<array>, std::vector<int>> vmap( std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& axes) override; const std::vector<int>& axes) override;
/** Print the primitive. */ /** The name of the primitive. */
void print(std::ostream& os) override { const char* name() const override {
os << "Axpby"; return "Axpby";
} }
/** Equivalence check **/ /** Equivalence check **/

View File

@ -74,9 +74,9 @@ class Axpby : public mx::Primitive {
const std::vector<mx::array>& inputs, const std::vector<mx::array>& inputs,
const std::vector<int>& axes) override; const std::vector<int>& axes) override;
/** Print the primitive. */ /** The name of the primitive. */
void print(std::ostream& os) override { const char* name() const override {
os << "Axpby"; return "Axpby";
} }
/** Equivalence check **/ /** Equivalence check **/

View File

@ -3,16 +3,9 @@
#include <dlfcn.h> #include <dlfcn.h>
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
namespace mlx::core { namespace mlx::core {
std::string get_primitive_string(Primitive* primitive) {
std::ostringstream op_t;
primitive->print(op_t);
return op_t.str();
}
std::filesystem::path current_binary_dir() { std::filesystem::path current_binary_dir() {
static std::filesystem::path binary_dir = []() { static std::filesystem::path binary_dir = []() {
Dl_info info; Dl_info info;

View File

@ -10,8 +10,6 @@
namespace mlx::core { namespace mlx::core {
std::string get_primitive_string(Primitive* primitive);
// Return the directory that contains current shared library. // Return the directory that contains current shared library.
std::filesystem::path current_binary_dir(); std::filesystem::path current_binary_dir();

View File

@ -231,7 +231,7 @@ inline void build_kernel(
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_" os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");" << std::endl; << namer.get_name(x.inputs()[0]) << ");" << std::endl;
} else { } else {
x.primitive().print(os); os << x.primitive().name();
os << "()("; os << "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) { for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", "; os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";

View File

@ -177,7 +177,7 @@ template <typename Op>
void binary_op_gpu_inplace( void binary_op_gpu_inplace(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
std::string_view op, const char* op,
const Stream& s) { const Stream& s) {
assert(inputs.size() > 1); assert(inputs.size() > 1);
const auto& a = inputs[0]; const auto& a = inputs[0];
@ -291,7 +291,7 @@ template <typename Op>
void binary_op_gpu( void binary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
std::string_view op, const char* op,
const Stream& s) { const Stream& s) {
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
@ -300,11 +300,11 @@ void binary_op_gpu(
binary_op_gpu_inplace<Op>(inputs, out, op, s); binary_op_gpu_inplace<Op>(inputs, out, op, s);
} }
#define BINARY_GPU(func) \ #define BINARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \ void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \ nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \ auto& s = out.primitive().stream(); \
binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \ binary_op_gpu<cu::func>(inputs, out, name(), s); \
} }
BINARY_GPU(Add) BINARY_GPU(Add)
@ -328,33 +328,31 @@ BINARY_GPU(Subtract)
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) { void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Equal::eval_gpu"); nvtx3::scoped_range r("Equal::eval_gpu");
auto& s = out.primitive().stream(); auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
if (equal_nan_) { if (equal_nan_) {
binary_op_gpu<cu::NaNEqual>(inputs, out, op, s); binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
} else { } else {
binary_op_gpu<cu::Equal>(inputs, out, op, s); binary_op_gpu<cu::Equal>(inputs, out, name(), s);
} }
} }
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) { void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu"); nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream(); auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
switch (op_) { switch (op_) {
case BitwiseBinary::And: case BitwiseBinary::And:
binary_op_gpu<cu::BitwiseAnd>(inputs, out, op, s); binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
break; break;
case BitwiseBinary::Or: case BitwiseBinary::Or:
binary_op_gpu<cu::BitwiseOr>(inputs, out, op, s); binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
break; break;
case BitwiseBinary::Xor: case BitwiseBinary::Xor:
binary_op_gpu<cu::BitwiseXor>(inputs, out, op, s); binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
break; break;
case BitwiseBinary::LeftShift: case BitwiseBinary::LeftShift:
binary_op_gpu<cu::LeftShift>(inputs, out, op, s); binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
break; break;
case BitwiseBinary::RightShift: case BitwiseBinary::RightShift:
binary_op_gpu<cu::RightShift>(inputs, out, op, s); binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
break; break;
} }
} }

View File

@ -184,7 +184,7 @@ template <typename Op>
void binary_two_op_gpu_inplace( void binary_two_op_gpu_inplace(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
std::string_view op, const char* op,
const Stream& s) { const Stream& s) {
assert(inputs.size() > 1); assert(inputs.size() > 1);
const auto& a = inputs[0]; const auto& a = inputs[0];
@ -314,7 +314,7 @@ template <typename Op>
void binary_two_op_gpu( void binary_two_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
std::string_view op, const char* op,
const Stream& s) { const Stream& s) {
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
@ -329,7 +329,7 @@ void DivMod::eval_gpu(
std::vector<array>& outputs) { std::vector<array>& outputs) {
nvtx3::scoped_range r("DivMod::eval_gpu"); nvtx3::scoped_range r("DivMod::eval_gpu");
auto& s = outputs[0].primitive().stream(); auto& s = outputs[0].primitive().stream();
binary_two_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s); binary_two_op_gpu<cu::DivMod>(inputs, outputs, name(), s);
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -106,9 +106,7 @@ struct FusedKernelBuilder {
value = fmt::format( value = fmt::format(
"static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0])); "static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0]));
} else { } else {
std::ostringstream ss; value = x.primitive().name();
x.primitive().print(ss);
value = ss.str();
value += "{}("; value += "{}(";
for (size_t i = 0; i < x.inputs().size() - 1; ++i) { for (size_t i = 0; i < x.inputs().size() - 1; ++i) {
value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i])); value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i]));

View File

@ -102,7 +102,7 @@ template <typename Op>
void unary_op_gpu_inplace( void unary_op_gpu_inplace(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const std::string& op, const char* op,
const Stream& s) { const Stream& s) {
auto& in = inputs[0]; auto& in = inputs[0];
if (in.size() == 0) { if (in.size() == 0) {
@ -178,17 +178,17 @@ template <typename Op>
void unary_op_gpu( void unary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const std::string& op, const char* op,
const Stream& s) { const Stream& s) {
set_unary_output_data(inputs[0], out); set_unary_output_data(inputs[0], out);
unary_op_gpu_inplace<Op>(inputs, out, op, s); unary_op_gpu_inplace<Op>(inputs, out, op, s);
} }
#define UNARY_GPU(func) \ #define UNARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \ void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \ nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \ auto& s = out.primitive().stream(); \
unary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \ unary_op_gpu<cu::func>(inputs, out, name(), s); \
} }
UNARY_GPU(Abs) UNARY_GPU(Abs)
@ -224,16 +224,15 @@ UNARY_GPU(Tanh)
void Log::eval_gpu(const std::vector<array>& inputs, array& out) { void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Log::eval_gpu"); nvtx3::scoped_range r("Log::eval_gpu");
auto& s = out.primitive().stream(); auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
switch (base_) { switch (base_) {
case Base::e: case Base::e:
unary_op_gpu<cu::Log>(inputs, out, op, s); unary_op_gpu<cu::Log>(inputs, out, name(), s);
break; break;
case Base::two: case Base::two:
unary_op_gpu<cu::Log2>(inputs, out, op, s); unary_op_gpu<cu::Log2>(inputs, out, name(), s);
break; break;
case Base::ten: case Base::ten:
unary_op_gpu<cu::Log10>(inputs, out, op, s); unary_op_gpu<cu::Log10>(inputs, out, name(), s);
break; break;
} }
} }
@ -244,7 +243,7 @@ void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
const auto& in = inputs[0]; const auto& in = inputs[0];
auto& s = out.primitive().stream(); auto& s = out.primitive().stream();
if (issubdtype(in.dtype(), inexact)) { if (issubdtype(in.dtype(), inexact)) {
unary_op_gpu<cu::Round>(inputs, out, get_primitive_string(this), s); unary_op_gpu<cu::Round>(inputs, out, name(), s);
} else { } else {
// No-op integer types // No-op integer types
out.copy_shared_buffer(in); out.copy_shared_buffer(in);

View File

@ -7,20 +7,20 @@
#define BINARY_GPU(func) \ #define BINARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \ void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
binary_op_gpu(inputs, out, get_primitive_string(this)); \ binary_op_gpu(inputs, out, name()); \
} }
#define BINARY_GPU_MULTI(func) \ #define BINARY_GPU_MULTI(func) \
void func::eval_gpu( \ void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \ const std::vector<array>& inputs, std::vector<array>& outputs) { \
binary_op_gpu(inputs, outputs, get_primitive_string(this)); \ binary_op_gpu(inputs, outputs, name()); \
} }
namespace mlx::core { namespace mlx::core {
std::string get_kernel_name( std::string get_kernel_name(
BinaryOpType bopt, BinaryOpType bopt,
const std::string& op, const char* op,
const array& a, const array& a,
bool large, bool large,
int ndim, int ndim,
@ -65,7 +65,7 @@ std::string get_kernel_name(
void binary_op_gpu_inplace( void binary_op_gpu_inplace(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const std::string& op, const char* op,
const Stream& s) { const Stream& s) {
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
@ -165,7 +165,7 @@ void binary_op_gpu_inplace(
void binary_op_gpu( void binary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const std::string& op, const char* op,
const Stream& s) { const Stream& s) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
@ -179,7 +179,7 @@ void binary_op_gpu(
void binary_op_gpu( void binary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const std::string& op) { const char* op) {
auto& s = outputs[0].primitive().stream(); auto& s = outputs[0].primitive().stream();
binary_op_gpu(inputs, outputs, op, s); binary_op_gpu(inputs, outputs, op, s);
} }
@ -187,7 +187,7 @@ void binary_op_gpu(
void binary_op_gpu_inplace( void binary_op_gpu_inplace(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const std::string& op, const char* op,
const Stream& s) { const Stream& s) {
std::vector<array> outputs = {out}; std::vector<array> outputs = {out};
binary_op_gpu_inplace(inputs, outputs, op, s); binary_op_gpu_inplace(inputs, outputs, op, s);
@ -196,7 +196,7 @@ void binary_op_gpu_inplace(
void binary_op_gpu( void binary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const std::string& op, const char* op,
const Stream& s) { const Stream& s) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
@ -209,7 +209,7 @@ void binary_op_gpu(
void binary_op_gpu( void binary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const std::string& op) { const char* op) {
auto& s = out.primitive().stream(); auto& s = out.primitive().stream();
binary_op_gpu(inputs, out, op, s); binary_op_gpu(inputs, out, op, s);
} }
@ -237,19 +237,19 @@ BINARY_GPU(Subtract)
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) { void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (op_) { switch (op_) {
case BitwiseBinary::And: case BitwiseBinary::And:
binary_op_gpu(inputs, out, get_primitive_string(this)); binary_op_gpu(inputs, out, name());
break; break;
case BitwiseBinary::Or: case BitwiseBinary::Or:
binary_op_gpu(inputs, out, get_primitive_string(this)); binary_op_gpu(inputs, out, name());
break; break;
case BitwiseBinary::Xor: case BitwiseBinary::Xor:
binary_op_gpu(inputs, out, get_primitive_string(this)); binary_op_gpu(inputs, out, name());
break; break;
case BitwiseBinary::LeftShift: case BitwiseBinary::LeftShift:
binary_op_gpu(inputs, out, get_primitive_string(this)); binary_op_gpu(inputs, out, name());
break; break;
case BitwiseBinary::RightShift: case BitwiseBinary::RightShift:
binary_op_gpu(inputs, out, get_primitive_string(this)); binary_op_gpu(inputs, out, name());
break; break;
} }
} }

View File

@ -9,25 +9,25 @@ namespace mlx::core {
void binary_op_gpu( void binary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const std::string& op, const char* op,
const Stream& s); const Stream& s);
void binary_op_gpu( void binary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const std::string& op, const char* op,
const Stream& s); const Stream& s);
void binary_op_gpu_inplace( void binary_op_gpu_inplace(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const std::string& op, const char* op,
const Stream& s); const Stream& s);
void binary_op_gpu_inplace( void binary_op_gpu_inplace(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const std::string& op, const char* op,
const Stream& s); const Stream& s);
} // namespace mlx::core } // namespace mlx::core

View File

@ -212,9 +212,7 @@ inline void build_kernel(
get_type_string(x.dtype()), get_type_string(x.dtype()),
namer.get_name(x.inputs()[0])); namer.get_name(x.inputs()[0]));
} else { } else {
std::ostringstream ss; os += x.primitive().name();
x.primitive().print(ss);
os += ss.str();
os += "()("; os += "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) { for (int i = 0; i < x.inputs().size() - 1; i++) {
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i])); os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));

View File

@ -8,12 +8,6 @@ using namespace fmt::literals;
namespace mlx::core { namespace mlx::core {
std::string op_name(const array& arr) {
std::ostringstream op_t;
arr.primitive().print(op_t);
return op_t.str();
}
MTL::ComputePipelineState* get_arange_kernel( MTL::ComputePipelineState* get_arange_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
@ -33,7 +27,7 @@ MTL::ComputePipelineState* get_unary_kernel(
const std::string& kernel_name, const std::string& kernel_name,
Dtype in_type, Dtype in_type,
Dtype out_type, Dtype out_type,
const std::string op) { const char* op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
auto in_t = get_type_string(in_type); auto in_t = get_type_string(in_type);
@ -58,10 +52,10 @@ MTL::ComputePipelineState* get_unary_kernel(
} }
void append_binary_kernels( void append_binary_kernels(
const std::string lib_name, const std::string& lib_name,
Dtype in_type, Dtype in_type,
Dtype out_type, Dtype out_type,
const std::string op, const char* op,
std::string& kernel_source) { std::string& kernel_source) {
const std::array<std::pair<std::string, std::string>, 7> kernel_types = {{ const std::array<std::pair<std::string, std::string>, 7> kernel_types = {{
{"ss", "binary_ss"}, {"ss", "binary_ss"},
@ -112,7 +106,7 @@ MTL::ComputePipelineState* get_binary_kernel(
const std::string& kernel_name, const std::string& kernel_name,
Dtype in_type, Dtype in_type,
Dtype out_type, Dtype out_type,
const std::string op) { const char* op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source; std::string kernel_source;
@ -129,7 +123,7 @@ MTL::ComputePipelineState* get_binary_two_kernel(
const std::string& kernel_name, const std::string& kernel_name,
Dtype in_type, Dtype in_type,
Dtype out_type, Dtype out_type,
const std::string op) { const char* op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils(); std::string kernel_source = metal::utils();
@ -144,7 +138,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
Dtype type, Dtype type,
const std::string op) { const char* op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
auto t_str = get_type_string(type); auto t_str = get_type_string(type);

View File

@ -19,27 +19,27 @@ MTL::ComputePipelineState* get_unary_kernel(
const std::string& kernel_name, const std::string& kernel_name,
Dtype in_type, Dtype in_type,
Dtype out_type, Dtype out_type,
const std::string op); const char* op);
MTL::ComputePipelineState* get_binary_kernel( MTL::ComputePipelineState* get_binary_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
Dtype in_type, Dtype in_type,
Dtype out_type, Dtype out_type,
const std::string op); const char* op);
MTL::ComputePipelineState* get_binary_two_kernel( MTL::ComputePipelineState* get_binary_two_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
Dtype in_type, Dtype in_type,
Dtype out_type, Dtype out_type,
const std::string op); const char* op);
MTL::ComputePipelineState* get_ternary_kernel( MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
Dtype type, Dtype type,
const std::string op); const char* op);
MTL::ComputePipelineState* get_copy_kernel( MTL::ComputePipelineState* get_copy_kernel(
metal::Device& d, metal::Device& d,
@ -257,8 +257,10 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
// Create a GPU kernel template definition for JIT compilation // Create a GPU kernel template definition for JIT compilation
template <typename... Args> template <typename... Args>
std::string std::string get_template_definition(
get_template_definition(std::string name, std::string func, Args... args) { std::string_view name,
std::string_view func,
Args... args) {
std::ostringstream s; std::ostringstream s;
s << func << "<"; s << func << "<";
bool first = true; bool first = true;

View File

@ -18,7 +18,7 @@ MTL::ComputePipelineState* get_unary_kernel(
const std::string& kernel_name, const std::string& kernel_name,
Dtype, Dtype,
Dtype, Dtype,
const std::string) { const char*) {
return d.get_kernel(kernel_name); return d.get_kernel(kernel_name);
} }
@ -27,7 +27,7 @@ MTL::ComputePipelineState* get_binary_kernel(
const std::string& kernel_name, const std::string& kernel_name,
Dtype, Dtype,
Dtype, Dtype,
const std::string) { const char*) {
return d.get_kernel(kernel_name); return d.get_kernel(kernel_name);
} }
@ -36,7 +36,7 @@ MTL::ComputePipelineState* get_binary_two_kernel(
const std::string& kernel_name, const std::string& kernel_name,
Dtype, Dtype,
Dtype, Dtype,
const std::string) { const char*) {
return d.get_kernel(kernel_name); return d.get_kernel(kernel_name);
} }
@ -44,7 +44,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
Dtype, Dtype,
const std::string) { const char*) {
return d.get_kernel(kernel_name); return d.get_kernel(kernel_name);
} }

View File

@ -11,7 +11,7 @@ namespace mlx::core {
void ternary_op_gpu_inplace( void ternary_op_gpu_inplace(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const std::string op, const char* op,
const Stream& s) { const Stream& s) {
assert(inputs.size() == 3); assert(inputs.size() == 3);
auto& a = inputs[0]; auto& a = inputs[0];
@ -128,7 +128,7 @@ void ternary_op_gpu_inplace(
void ternary_op_gpu( void ternary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const std::string op, const char* op,
const Stream& s) { const Stream& s) {
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
@ -141,13 +141,13 @@ void ternary_op_gpu(
void ternary_op_gpu( void ternary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const std::string op) { const char* op) {
auto& s = out.primitive().stream(); auto& s = out.primitive().stream();
ternary_op_gpu(inputs, out, op, s); ternary_op_gpu(inputs, out, op, s);
} }
void Select::eval_gpu(const std::vector<array>& inputs, array& out) { void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
ternary_op_gpu(inputs, out, get_primitive_string(this)); ternary_op_gpu(inputs, out, name());
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -9,13 +9,13 @@ namespace mlx::core {
void ternary_op_gpu( void ternary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const std::string op, const char* op,
const Stream& s); const Stream& s);
void ternary_op_gpu_inplace( void ternary_op_gpu_inplace(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const std::string op, const char* op,
const Stream& s); const Stream& s);
} // namespace mlx::core } // namespace mlx::core

View File

@ -8,7 +8,7 @@
#define UNARY_GPU(func) \ #define UNARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \ void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
unary_op_gpu(inputs, out, get_primitive_string(this)); \ unary_op_gpu(inputs, out, name()); \
} }
namespace mlx::core { namespace mlx::core {
@ -16,7 +16,7 @@ namespace mlx::core {
void unary_op_gpu_inplace( void unary_op_gpu_inplace(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const std::string op, const char* op,
const Stream& s) { const Stream& s) {
auto& in = inputs[0]; auto& in = inputs[0];
bool contig = in.flags().contiguous; bool contig = in.flags().contiguous;
@ -98,7 +98,7 @@ void unary_op_gpu_inplace(
void unary_op_gpu( void unary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const std::string op, const char* op,
const Stream& s) { const Stream& s) {
set_unary_output_data(inputs[0], out); set_unary_output_data(inputs[0], out);
unary_op_gpu_inplace(inputs, out, op, s); unary_op_gpu_inplace(inputs, out, op, s);
@ -107,7 +107,7 @@ void unary_op_gpu(
void unary_op_gpu( void unary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const std::string op) { const char* op) {
auto& s = out.primitive().stream(); auto& s = out.primitive().stream();
unary_op_gpu(inputs, out, op, s); unary_op_gpu(inputs, out, op, s);
} }
@ -146,13 +146,13 @@ UNARY_GPU(Tanh)
void Log::eval_gpu(const std::vector<array>& inputs, array& out) { void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (base_) { switch (base_) {
case Base::e: case Base::e:
unary_op_gpu(inputs, out, get_primitive_string(this)); unary_op_gpu(inputs, out, name());
break; break;
case Base::two: case Base::two:
unary_op_gpu(inputs, out, get_primitive_string(this)); unary_op_gpu(inputs, out, name());
break; break;
case Base::ten: case Base::ten:
unary_op_gpu(inputs, out, get_primitive_string(this)); unary_op_gpu(inputs, out, name());
break; break;
} }
} }
@ -161,7 +161,7 @@ void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) { if (issubdtype(in.dtype(), inexact)) {
unary_op_gpu(inputs, out, get_primitive_string(this)); unary_op_gpu(inputs, out, name());
} else { } else {
// No-op integer types // No-op integer types
out.copy_shared_buffer(in); out.copy_shared_buffer(in);

View File

@ -9,13 +9,13 @@ namespace mlx::core {
void unary_op_gpu( void unary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const std::string op, const char* op,
const Stream& s); const Stream& s);
void unary_op_gpu_inplace( void unary_op_gpu_inplace(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const std::string op, const char* op,
const Stream& s); const Stream& s);
} // namespace mlx::core } // namespace mlx::core

View File

@ -40,7 +40,7 @@ inline void debug_set_primitive_buffer_label(
if (auto cbuf_label = command_buffer->label(); cbuf_label) { if (auto cbuf_label = command_buffer->label(); cbuf_label) {
label << cbuf_label->utf8String(); label << cbuf_label->utf8String();
} }
primitive.print(label); label << primitive.name();
command_buffer->setLabel(make_string(label)); command_buffer->setLabel(make_string(label));
#endif #endif
} }

View File

@ -107,7 +107,7 @@ Compiled::Compiled(
// name and type of output // name and type of output
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize(); os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
// computation performed // computation performed
a.primitive().print(os); os << a.primitive().name();
// name of inputs to the function // name of inputs to the function
for (auto& inp : a.inputs()) { for (auto& inp : a.inputs()) {
os << namer.get_name(inp); os << namer.get_name(inp);
@ -170,11 +170,16 @@ bool Compiled::is_equivalent(const Primitive& other) const {
}); });
} }
void Compiled::print(std::ostream& os) { const char* Compiled::name() const {
os << "Compiled"; if (name_.empty()) {
for (auto& a : tape_) { std::ostringstream os;
a.primitive().print(os); os << "Compiled";
for (auto& a : tape_) {
os << a.primitive().name();
}
name_ = os.str();
} }
return name_.c_str();
} }
std::vector<Shape> Compiled::output_shapes(const std::vector<array>& inputs) { std::vector<Shape> Compiled::output_shapes(const std::vector<array>& inputs) {

View File

@ -45,27 +45,22 @@ class AllReduce : public DistPrimitive {
const std::vector<int>& argnums, const std::vector<int>& argnums,
const std::vector<array>& outputs) override; const std::vector<array>& outputs) override;
void print(std::ostream& os) override { const char* name() const override {
switch (reduce_type_) { switch (reduce_type_) {
case And: case And:
os << "And"; return "And AllReduce";
case Or: case Or:
os << "And"; return "Or AllReduce";
break;
case Sum: case Sum:
os << "Sum"; return "Sum AllReduce";
break;
case Prod: case Prod:
os << "Prod"; return "Prod AllReduce";
break;
case Min: case Min:
os << "Min"; return "Min AllReduce";
break;
case Max: case Max:
os << "Max"; return "Max AllReduce";
break;
} }
os << " AllReduce"; return "<unknwon AllReduce>";
} }
private: private:
@ -94,7 +89,7 @@ class AllGather : public DistPrimitive {
const std::vector<int>& argnums, const std::vector<int>& argnums,
const std::vector<array>& outputs) override; const std::vector<array>& outputs) override;
DEFINE_PRINT(AllGather); DEFINE_NAME(AllGather);
}; };
class Send : public DistPrimitive { class Send : public DistPrimitive {
@ -110,7 +105,7 @@ class Send : public DistPrimitive {
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& axes) override; const std::vector<int>& axes) override;
DEFINE_PRINT(Send); DEFINE_NAME(Send);
private: private:
int dst_; int dst_;
@ -126,7 +121,7 @@ class Recv : public DistPrimitive {
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override; override;
DEFINE_PRINT(Recv); DEFINE_NAME(Recv);
private: private:
int src_; int src_;

View File

@ -354,9 +354,7 @@ struct PrimitiveFactory {
void save(Writer& os, const std::shared_ptr<Primitive>& p) { void save(Writer& os, const std::shared_ptr<Primitive>& p) {
serialize(os, p->stream()); serialize(os, p->stream());
std::ostringstream pout; std::string name = p->name();
p->print(pout);
auto name = pout.str();
name = name.substr(0, name.find(' ')); name = name.substr(0, name.find(' '));
if (auto it = name_remap.find(name); it != name_remap.end()) { if (auto it = name_remap.find(name); it != name_remap.end()) {
name = it->second; name = it->second;

View File

@ -58,7 +58,7 @@ class RMSNorm : public Custom {
const std::vector<int>& argnums, const std::vector<int>& argnums,
const std::vector<array>& outputs) override; const std::vector<array>& outputs) override;
DEFINE_PRINT(RMSNorm) DEFINE_NAME(RMSNorm)
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
@ -85,7 +85,7 @@ class RMSNormVJP : public Custom {
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override; override;
DEFINE_PRINT(RMSNormVJP) DEFINE_NAME(RMSNormVJP)
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
auto state() const { auto state() const {
return std::make_pair(nullptr, eps_); return std::make_pair(nullptr, eps_);
@ -118,7 +118,7 @@ class LayerNorm : public Custom {
const std::vector<int>& argnums, const std::vector<int>& argnums,
const std::vector<array>& outputs) override; const std::vector<array>& outputs) override;
DEFINE_PRINT(LayerNorm) DEFINE_NAME(LayerNorm)
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const { auto state() const {
@ -144,7 +144,7 @@ class LayerNormVJP : public Custom {
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override; override;
DEFINE_PRINT(LayerNormVJP) DEFINE_NAME(LayerNormVJP)
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
auto state() const { auto state() const {
return std::make_pair(nullptr, eps_); return std::make_pair(nullptr, eps_);
@ -186,7 +186,7 @@ class RoPE : public Custom {
const std::vector<int>& argnums, const std::vector<int>& argnums,
const std::vector<array>& outputs) override; const std::vector<array>& outputs) override;
DEFINE_PRINT(RoPE) DEFINE_NAME(RoPE)
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const { auto state() const {
@ -233,7 +233,7 @@ class ScaledDotProductAttention : public Custom {
void eval_gpu(const std::vector<array>& inputs, array& out); void eval_gpu(const std::vector<array>& inputs, array& out);
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
DEFINE_PRINT(ScaledDotProductAttention); DEFINE_NAME(ScaledDotProductAttention);
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const { auto state() const {
return std::make_tuple(nullptr, scale_, do_causal_); return std::make_tuple(nullptr, scale_, do_causal_);
@ -263,7 +263,7 @@ class AffineQuantize : public Custom {
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override; override;
DEFINE_PRINT(AffineQuantize); DEFINE_NAME(AffineQuantize);
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override; std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
@ -311,7 +311,7 @@ class CustomKernel : public Primitive {
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override; override;
DEFINE_PRINT(CustomKernel); DEFINE_NAME(CustomKernel);
private: private:
std::string source_; std::string source_;

View File

@ -93,7 +93,7 @@ void print_graph(
os << "\n"; os << "\n";
for (auto& arr : tape) { for (auto& arr : tape) {
arr.primitive().print(os); os << arr.primitive().name();
os << " "; os << " ";
print_arrs(arr.inputs()); print_arrs(arr.inputs());
os << " -> "; os << " -> ";
@ -143,7 +143,7 @@ void export_to_dot(
os << "{ "; os << "{ ";
os << x.primitive_id(); os << x.primitive_id();
os << " [label =\""; os << " [label =\"";
x.primitive().print(os); os << x.primitive().name();
os << "\", shape=rectangle]"; os << "\", shape=rectangle]";
os << "; }" << std::endl; os << "; }" << std::endl;
// Arrows to primitive's inputs // Arrows to primitive's inputs

View File

@ -500,7 +500,7 @@ array cross(
void validate_eig( void validate_eig(
const array& a, const array& a,
const StreamOrDevice& stream, const StreamOrDevice& stream,
const std::string fname) { const std::string& fname) {
check_cpu_stream(stream, fname); check_cpu_stream(stream, fname);
check_float_or_complex(a.dtype(), fname); check_float_or_complex(a.dtype(), fname);

View File

@ -181,7 +181,7 @@ std::vector<array> Primitive::jvp(
const std::vector<int>&) { const std::vector<int>&) {
std::ostringstream msg; std::ostringstream msg;
msg << "[Primitive::jvp] Not implemented for "; msg << "[Primitive::jvp] Not implemented for ";
print(msg); msg << name();
msg << "."; msg << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
@ -193,7 +193,7 @@ std::vector<array> Primitive::vjp(
const std::vector<array>&) { const std::vector<array>&) {
std::ostringstream msg; std::ostringstream msg;
msg << "[Primitive::vjp] Not implemented for "; msg << "[Primitive::vjp] Not implemented for ";
print(msg); msg << name();
msg << "."; msg << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
@ -203,7 +203,7 @@ std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
const std::vector<int>&) { const std::vector<int>&) {
std::ostringstream msg; std::ostringstream msg;
msg << "[Primitive::vmap] Not implemented for "; msg << "[Primitive::vmap] Not implemented for ";
print(msg); msg << name();
msg << "."; msg << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
@ -211,7 +211,7 @@ std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
std::vector<Shape> Primitive::output_shapes(const std::vector<array>&) { std::vector<Shape> Primitive::output_shapes(const std::vector<array>&) {
std::ostringstream msg; std::ostringstream msg;
msg << "[Primitive::output_shapes] "; msg << "[Primitive::output_shapes] ";
this->print(msg); msg << name();
msg << " cannot infer output shapes."; msg << " cannot infer output shapes.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
@ -743,26 +743,6 @@ bool BitwiseBinary::is_equivalent(const Primitive& other) const {
return op_ == a_other.op_; return op_ == a_other.op_;
} }
void BitwiseBinary::print(std::ostream& os) {
switch (op_) {
case BitwiseBinary::And:
os << "BitwiseAnd";
break;
case BitwiseBinary::Or:
os << "BitwiseOr";
break;
case BitwiseBinary::Xor:
os << "BitwiseXor";
break;
case BitwiseBinary::LeftShift:
os << "LeftShift";
break;
case BitwiseBinary::RightShift:
os << "RightShift";
break;
}
}
std::pair<std::vector<array>, std::vector<int>> BitwiseBinary::vmap( std::pair<std::vector<array>, std::vector<int>> BitwiseBinary::vmap(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& axes) { const std::vector<int>& axes) {
@ -5375,8 +5355,13 @@ std::pair<std::vector<array>, std::vector<int>> View::vmap(
return {{view(inputs[0], dtype_, stream())}, axes}; return {{view(inputs[0], dtype_, stream())}, axes};
} }
void View::print(std::ostream& os) { const char* View::name() const {
os << "View " << dtype_; if (name_.empty()) {
std::ostringstream os;
os << "View " << dtype_;
name_ = os.str();
}
return name_.c_str();
} }
bool View::is_equivalent(const Primitive& other) const { bool View::is_equivalent(const Primitive& other) const {

File diff suppressed because it is too large Load Diff

View File

@ -33,7 +33,7 @@ class Synchronizer : public Primitive {
void eval_cpu(const std::vector<array>&, std::vector<array>&) override {} void eval_cpu(const std::vector<array>&, std::vector<array>&) override {}
void eval_gpu(const std::vector<array>&, std::vector<array>&) override {} void eval_gpu(const std::vector<array>&, std::vector<array>&) override {}
DEFINE_PRINT(Synchronize); DEFINE_NAME(Synchronize);
}; };
// Initialize the static tracing members from transforms_impl.h // Initialize the static tracing members from transforms_impl.h

View File

@ -514,7 +514,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"eigh", "eigh",
[](const mx::array& a, const std::string UPLO, mx::StreamOrDevice s) { [](const mx::array& a, const std::string& UPLO, mx::StreamOrDevice s) {
auto result = mx::linalg::eigh(a, UPLO, s); auto result = mx::linalg::eigh(a, UPLO, s);
return nb::make_tuple(result.first, result.second); return nb::make_tuple(result.first, result.second);
}, },

View File

@ -14,7 +14,7 @@ namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
using namespace nb::literals; using namespace nb::literals;
bool DEPRECATE(const std::string& old_fn, const std::string new_fn) { bool DEPRECATE(const char* old_fn, const char* new_fn) {
std::cerr << old_fn << " is deprecated and will be removed in a future " std::cerr << old_fn << " is deprecated and will be removed in a future "
<< "version. Use " << new_fn << " instead." << std::endl; << "version. Use " << new_fn << " instead." << std::endl;
return true; return true;

View File

@ -3076,7 +3076,7 @@ void init_ops(nb::module_& m) {
std::tuple<int>, std::tuple<int>,
std::pair<int, int>, std::pair<int, int>,
std::vector<std::pair<int, int>>>& pad_width, std::vector<std::pair<int, int>>>& pad_width,
const std::string mode, const std::string& mode,
const ScalarOrArray& constant_value, const ScalarOrArray& constant_value,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
if (auto pv = std::get_if<int>(&pad_width); pv) { if (auto pv = std::get_if<int>(&pad_width); pv) {