Add Primitive::name and remove Primitive::print (#2365)

This commit is contained in:
Cheng 2025-07-15 06:06:35 +09:00 committed by GitHub
parent 5201df5030
commit d34f887abc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 307 additions and 340 deletions

View File

@ -138,13 +138,13 @@ more concrete:
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
/** Print the primitive. */
void print(std::ostream& os) override {
os << "Axpby";
/** The name of the primitive. */
const char* name() const override {
return "Axpby";
}
/** Equivalence check **/

View File

@ -74,9 +74,9 @@ class Axpby : public mx::Primitive {
const std::vector<mx::array>& inputs,
const std::vector<int>& axes) override;
/** Print the primitive. */
void print(std::ostream& os) override {
os << "Axpby";
/** The name of the primitive. */
const char* name() const override {
return "Axpby";
}
/** Equivalence check **/

View File

@ -3,16 +3,9 @@
#include <dlfcn.h>
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
std::string get_primitive_string(Primitive* primitive) {
std::ostringstream op_t;
primitive->print(op_t);
return op_t.str();
}
std::filesystem::path current_binary_dir() {
static std::filesystem::path binary_dir = []() {
Dl_info info;

View File

@ -10,8 +10,6 @@
namespace mlx::core {
std::string get_primitive_string(Primitive* primitive);
// Return the directory that contains current shared library.
std::filesystem::path current_binary_dir();

View File

@ -231,7 +231,7 @@ inline void build_kernel(
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
} else {
x.primitive().print(os);
os << x.primitive().name();
os << "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";

View File

@ -177,7 +177,7 @@ template <typename Op>
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
std::string_view op,
const char* op,
const Stream& s) {
assert(inputs.size() > 1);
const auto& a = inputs[0];
@ -291,7 +291,7 @@ template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
std::string_view op,
const char* op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
@ -300,11 +300,11 @@ void binary_op_gpu(
binary_op_gpu_inplace<Op>(inputs, out, op, s);
}
#define BINARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \
binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
#define BINARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \
binary_op_gpu<cu::func>(inputs, out, name(), s); \
}
BINARY_GPU(Add)
@ -328,33 +328,31 @@ BINARY_GPU(Subtract)
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Equal::eval_gpu");
auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
if (equal_nan_) {
binary_op_gpu<cu::NaNEqual>(inputs, out, op, s);
binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
} else {
binary_op_gpu<cu::Equal>(inputs, out, op, s);
binary_op_gpu<cu::Equal>(inputs, out, name(), s);
}
}
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu<cu::BitwiseAnd>(inputs, out, op, s);
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
break;
case BitwiseBinary::Or:
binary_op_gpu<cu::BitwiseOr>(inputs, out, op, s);
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
break;
case BitwiseBinary::Xor:
binary_op_gpu<cu::BitwiseXor>(inputs, out, op, s);
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
break;
case BitwiseBinary::LeftShift:
binary_op_gpu<cu::LeftShift>(inputs, out, op, s);
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
break;
case BitwiseBinary::RightShift:
binary_op_gpu<cu::RightShift>(inputs, out, op, s);
binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
break;
}
}

View File

@ -184,7 +184,7 @@ template <typename Op>
void binary_two_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const char* op,
const Stream& s) {
assert(inputs.size() > 1);
const auto& a = inputs[0];
@ -314,7 +314,7 @@ template <typename Op>
void binary_two_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const char* op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
@ -329,7 +329,7 @@ void DivMod::eval_gpu(
std::vector<array>& outputs) {
nvtx3::scoped_range r("DivMod::eval_gpu");
auto& s = outputs[0].primitive().stream();
binary_two_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
binary_two_op_gpu<cu::DivMod>(inputs, outputs, name(), s);
}
} // namespace mlx::core

View File

@ -106,9 +106,7 @@ struct FusedKernelBuilder {
value = fmt::format(
"static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0]));
} else {
std::ostringstream ss;
x.primitive().print(ss);
value = ss.str();
value = x.primitive().name();
value += "{}(";
for (size_t i = 0; i < x.inputs().size() - 1; ++i) {
value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i]));

View File

@ -102,7 +102,7 @@ template <typename Op>
void unary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const char* op,
const Stream& s) {
auto& in = inputs[0];
if (in.size() == 0) {
@ -178,17 +178,17 @@ template <typename Op>
void unary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const char* op,
const Stream& s) {
set_unary_output_data(inputs[0], out);
unary_op_gpu_inplace<Op>(inputs, out, op, s);
}
#define UNARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \
unary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
#define UNARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \
unary_op_gpu<cu::func>(inputs, out, name(), s); \
}
UNARY_GPU(Abs)
@ -224,16 +224,15 @@ UNARY_GPU(Tanh)
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Log::eval_gpu");
auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
switch (base_) {
case Base::e:
unary_op_gpu<cu::Log>(inputs, out, op, s);
unary_op_gpu<cu::Log>(inputs, out, name(), s);
break;
case Base::two:
unary_op_gpu<cu::Log2>(inputs, out, op, s);
unary_op_gpu<cu::Log2>(inputs, out, name(), s);
break;
case Base::ten:
unary_op_gpu<cu::Log10>(inputs, out, op, s);
unary_op_gpu<cu::Log10>(inputs, out, name(), s);
break;
}
}
@ -244,7 +243,7 @@ void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
const auto& in = inputs[0];
auto& s = out.primitive().stream();
if (issubdtype(in.dtype(), inexact)) {
unary_op_gpu<cu::Round>(inputs, out, get_primitive_string(this), s);
unary_op_gpu<cu::Round>(inputs, out, name(), s);
} else {
// No-op integer types
out.copy_shared_buffer(in);

View File

@ -7,20 +7,20 @@
#define BINARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
binary_op_gpu(inputs, out, get_primitive_string(this)); \
binary_op_gpu(inputs, out, name()); \
}
#define BINARY_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
binary_op_gpu(inputs, outputs, get_primitive_string(this)); \
binary_op_gpu(inputs, outputs, name()); \
}
namespace mlx::core {
std::string get_kernel_name(
BinaryOpType bopt,
const std::string& op,
const char* op,
const array& a,
bool large,
int ndim,
@ -65,7 +65,7 @@ std::string get_kernel_name(
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string& op,
const char* op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
@ -165,7 +165,7 @@ void binary_op_gpu_inplace(
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string& op,
const char* op,
const Stream& s) {
assert(inputs.size() == 2);
auto& a = inputs[0];
@ -179,7 +179,7 @@ void binary_op_gpu(
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string& op) {
const char* op) {
auto& s = outputs[0].primitive().stream();
binary_op_gpu(inputs, outputs, op, s);
}
@ -187,7 +187,7 @@ void binary_op_gpu(
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const char* op,
const Stream& s) {
std::vector<array> outputs = {out};
binary_op_gpu_inplace(inputs, outputs, op, s);
@ -196,7 +196,7 @@ void binary_op_gpu_inplace(
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const char* op,
const Stream& s) {
assert(inputs.size() == 2);
auto& a = inputs[0];
@ -209,7 +209,7 @@ void binary_op_gpu(
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string& op) {
const char* op) {
auto& s = out.primitive().stream();
binary_op_gpu(inputs, out, op, s);
}
@ -237,19 +237,19 @@ BINARY_GPU(Subtract)
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu(inputs, out, get_primitive_string(this));
binary_op_gpu(inputs, out, name());
break;
case BitwiseBinary::Or:
binary_op_gpu(inputs, out, get_primitive_string(this));
binary_op_gpu(inputs, out, name());
break;
case BitwiseBinary::Xor:
binary_op_gpu(inputs, out, get_primitive_string(this));
binary_op_gpu(inputs, out, name());
break;
case BitwiseBinary::LeftShift:
binary_op_gpu(inputs, out, get_primitive_string(this));
binary_op_gpu(inputs, out, name());
break;
case BitwiseBinary::RightShift:
binary_op_gpu(inputs, out, get_primitive_string(this));
binary_op_gpu(inputs, out, name());
break;
}
}

View File

@ -9,25 +9,25 @@ namespace mlx::core {
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string& op,
const char* op,
const Stream& s);
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const char* op,
const Stream& s);
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string& op,
const char* op,
const Stream& s);
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const char* op,
const Stream& s);
} // namespace mlx::core

View File

@ -212,9 +212,7 @@ inline void build_kernel(
get_type_string(x.dtype()),
namer.get_name(x.inputs()[0]));
} else {
std::ostringstream ss;
x.primitive().print(ss);
os += ss.str();
os += x.primitive().name();
os += "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));

View File

@ -8,12 +8,6 @@ using namespace fmt::literals;
namespace mlx::core {
std::string op_name(const array& arr) {
std::ostringstream op_t;
arr.primitive().print(op_t);
return op_t.str();
}
MTL::ComputePipelineState* get_arange_kernel(
metal::Device& d,
const std::string& kernel_name,
@ -33,7 +27,7 @@ MTL::ComputePipelineState* get_unary_kernel(
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op) {
const char* op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
auto in_t = get_type_string(in_type);
@ -58,10 +52,10 @@ MTL::ComputePipelineState* get_unary_kernel(
}
void append_binary_kernels(
const std::string lib_name,
const std::string& lib_name,
Dtype in_type,
Dtype out_type,
const std::string op,
const char* op,
std::string& kernel_source) {
const std::array<std::pair<std::string, std::string>, 7> kernel_types = {{
{"ss", "binary_ss"},
@ -112,7 +106,7 @@ MTL::ComputePipelineState* get_binary_kernel(
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op) {
const char* op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source;
@ -129,7 +123,7 @@ MTL::ComputePipelineState* get_binary_two_kernel(
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op) {
const char* op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
@ -144,7 +138,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype type,
const std::string op) {
const char* op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
auto t_str = get_type_string(type);

View File

@ -19,27 +19,27 @@ MTL::ComputePipelineState* get_unary_kernel(
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op);
const char* op);
MTL::ComputePipelineState* get_binary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op);
const char* op);
MTL::ComputePipelineState* get_binary_two_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op);
const char* op);
MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype type,
const std::string op);
const char* op);
MTL::ComputePipelineState* get_copy_kernel(
metal::Device& d,
@ -257,8 +257,10 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
// Create a GPU kernel template definition for JIT compilation
template <typename... Args>
std::string
get_template_definition(std::string name, std::string func, Args... args) {
std::string get_template_definition(
std::string_view name,
std::string_view func,
Args... args) {
std::ostringstream s;
s << func << "<";
bool first = true;

View File

@ -18,7 +18,7 @@ MTL::ComputePipelineState* get_unary_kernel(
const std::string& kernel_name,
Dtype,
Dtype,
const std::string) {
const char*) {
return d.get_kernel(kernel_name);
}
@ -27,7 +27,7 @@ MTL::ComputePipelineState* get_binary_kernel(
const std::string& kernel_name,
Dtype,
Dtype,
const std::string) {
const char*) {
return d.get_kernel(kernel_name);
}
@ -36,7 +36,7 @@ MTL::ComputePipelineState* get_binary_two_kernel(
const std::string& kernel_name,
Dtype,
Dtype,
const std::string) {
const char*) {
return d.get_kernel(kernel_name);
}
@ -44,7 +44,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype,
const std::string) {
const char*) {
return d.get_kernel(kernel_name);
}

View File

@ -11,7 +11,7 @@ namespace mlx::core {
void ternary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string op,
const char* op,
const Stream& s) {
assert(inputs.size() == 3);
auto& a = inputs[0];
@ -128,7 +128,7 @@ void ternary_op_gpu_inplace(
void ternary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op,
const char* op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
@ -141,13 +141,13 @@ void ternary_op_gpu(
void ternary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op) {
const char* op) {
auto& s = out.primitive().stream();
ternary_op_gpu(inputs, out, op, s);
}
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
ternary_op_gpu(inputs, out, get_primitive_string(this));
ternary_op_gpu(inputs, out, name());
}
} // namespace mlx::core

View File

@ -9,13 +9,13 @@ namespace mlx::core {
void ternary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op,
const char* op,
const Stream& s);
void ternary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string op,
const char* op,
const Stream& s);
} // namespace mlx::core

View File

@ -8,7 +8,7 @@
#define UNARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
unary_op_gpu(inputs, out, get_primitive_string(this)); \
unary_op_gpu(inputs, out, name()); \
}
namespace mlx::core {
@ -16,7 +16,7 @@ namespace mlx::core {
void unary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string op,
const char* op,
const Stream& s) {
auto& in = inputs[0];
bool contig = in.flags().contiguous;
@ -98,7 +98,7 @@ void unary_op_gpu_inplace(
void unary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op,
const char* op,
const Stream& s) {
set_unary_output_data(inputs[0], out);
unary_op_gpu_inplace(inputs, out, op, s);
@ -107,7 +107,7 @@ void unary_op_gpu(
void unary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op) {
const char* op) {
auto& s = out.primitive().stream();
unary_op_gpu(inputs, out, op, s);
}
@ -146,13 +146,13 @@ UNARY_GPU(Tanh)
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (base_) {
case Base::e:
unary_op_gpu(inputs, out, get_primitive_string(this));
unary_op_gpu(inputs, out, name());
break;
case Base::two:
unary_op_gpu(inputs, out, get_primitive_string(this));
unary_op_gpu(inputs, out, name());
break;
case Base::ten:
unary_op_gpu(inputs, out, get_primitive_string(this));
unary_op_gpu(inputs, out, name());
break;
}
}
@ -161,7 +161,7 @@ void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_op_gpu(inputs, out, get_primitive_string(this));
unary_op_gpu(inputs, out, name());
} else {
// No-op integer types
out.copy_shared_buffer(in);

View File

@ -9,13 +9,13 @@ namespace mlx::core {
void unary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op,
const char* op,
const Stream& s);
void unary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string op,
const char* op,
const Stream& s);
} // namespace mlx::core

View File

@ -40,7 +40,7 @@ inline void debug_set_primitive_buffer_label(
if (auto cbuf_label = command_buffer->label(); cbuf_label) {
label << cbuf_label->utf8String();
}
primitive.print(label);
label << primitive.name();
command_buffer->setLabel(make_string(label));
#endif
}

View File

@ -107,7 +107,7 @@ Compiled::Compiled(
// name and type of output
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
// computation performed
a.primitive().print(os);
os << a.primitive().name();
// name of inputs to the function
for (auto& inp : a.inputs()) {
os << namer.get_name(inp);
@ -170,11 +170,16 @@ bool Compiled::is_equivalent(const Primitive& other) const {
});
}
void Compiled::print(std::ostream& os) {
os << "Compiled";
for (auto& a : tape_) {
a.primitive().print(os);
const char* Compiled::name() const {
if (name_.empty()) {
std::ostringstream os;
os << "Compiled";
for (auto& a : tape_) {
os << a.primitive().name();
}
name_ = os.str();
}
return name_.c_str();
}
std::vector<Shape> Compiled::output_shapes(const std::vector<array>& inputs) {

View File

@ -45,27 +45,22 @@ class AllReduce : public DistPrimitive {
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
void print(std::ostream& os) override {
const char* name() const override {
switch (reduce_type_) {
case And:
os << "And";
return "And AllReduce";
case Or:
os << "And";
break;
return "Or AllReduce";
case Sum:
os << "Sum";
break;
return "Sum AllReduce";
case Prod:
os << "Prod";
break;
return "Prod AllReduce";
case Min:
os << "Min";
break;
return "Min AllReduce";
case Max:
os << "Max";
break;
return "Max AllReduce";
}
os << " AllReduce";
return "<unknown AllReduce>";
}
private:
@ -94,7 +89,7 @@ class AllGather : public DistPrimitive {
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_PRINT(AllGather);
DEFINE_NAME(AllGather);
};
class Send : public DistPrimitive {
@ -110,7 +105,7 @@ class Send : public DistPrimitive {
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
DEFINE_PRINT(Send);
DEFINE_NAME(Send);
private:
int dst_;
@ -126,7 +121,7 @@ class Recv : public DistPrimitive {
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(Recv);
DEFINE_NAME(Recv);
private:
int src_;

View File

@ -354,9 +354,7 @@ struct PrimitiveFactory {
void save(Writer& os, const std::shared_ptr<Primitive>& p) {
serialize(os, p->stream());
std::ostringstream pout;
p->print(pout);
auto name = pout.str();
std::string name = p->name();
name = name.substr(0, name.find(' '));
if (auto it = name_remap.find(name); it != name_remap.end()) {
name = it->second;

View File

@ -58,7 +58,7 @@ class RMSNorm : public Custom {
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_PRINT(RMSNorm)
DEFINE_NAME(RMSNorm)
bool is_equivalent(const Primitive& other) const override;
DEFINE_INPUT_OUTPUT_SHAPE()
@ -85,7 +85,7 @@ class RMSNormVJP : public Custom {
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(RMSNormVJP)
DEFINE_NAME(RMSNormVJP)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_pair(nullptr, eps_);
@ -118,7 +118,7 @@ class LayerNorm : public Custom {
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_PRINT(LayerNorm)
DEFINE_NAME(LayerNorm)
bool is_equivalent(const Primitive& other) const override;
DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const {
@ -144,7 +144,7 @@ class LayerNormVJP : public Custom {
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(LayerNormVJP)
DEFINE_NAME(LayerNormVJP)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_pair(nullptr, eps_);
@ -186,7 +186,7 @@ class RoPE : public Custom {
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_PRINT(RoPE)
DEFINE_NAME(RoPE)
bool is_equivalent(const Primitive& other) const override;
DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const {
@ -233,7 +233,7 @@ class ScaledDotProductAttention : public Custom {
void eval_gpu(const std::vector<array>& inputs, array& out);
bool is_equivalent(const Primitive& other) const override;
DEFINE_PRINT(ScaledDotProductAttention);
DEFINE_NAME(ScaledDotProductAttention);
DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const {
return std::make_tuple(nullptr, scale_, do_causal_);
@ -263,7 +263,7 @@ class AffineQuantize : public Custom {
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(AffineQuantize);
DEFINE_NAME(AffineQuantize);
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
@ -311,7 +311,7 @@ class CustomKernel : public Primitive {
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(CustomKernel);
DEFINE_NAME(CustomKernel);
private:
std::string source_;

View File

@ -93,7 +93,7 @@ void print_graph(
os << "\n";
for (auto& arr : tape) {
arr.primitive().print(os);
os << arr.primitive().name();
os << " ";
print_arrs(arr.inputs());
os << " -> ";
@ -143,7 +143,7 @@ void export_to_dot(
os << "{ ";
os << x.primitive_id();
os << " [label =\"";
x.primitive().print(os);
os << x.primitive().name();
os << "\", shape=rectangle]";
os << "; }" << std::endl;
// Arrows to primitive's inputs

View File

@ -500,7 +500,7 @@ array cross(
void validate_eig(
const array& a,
const StreamOrDevice& stream,
const std::string fname) {
const std::string& fname) {
check_cpu_stream(stream, fname);
check_float_or_complex(a.dtype(), fname);

View File

@ -181,7 +181,7 @@ std::vector<array> Primitive::jvp(
const std::vector<int>&) {
std::ostringstream msg;
msg << "[Primitive::jvp] Not implemented for ";
print(msg);
msg << name();
msg << ".";
throw std::invalid_argument(msg.str());
}
@ -193,7 +193,7 @@ std::vector<array> Primitive::vjp(
const std::vector<array>&) {
std::ostringstream msg;
msg << "[Primitive::vjp] Not implemented for ";
print(msg);
msg << name();
msg << ".";
throw std::invalid_argument(msg.str());
}
@ -203,7 +203,7 @@ std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
const std::vector<int>&) {
std::ostringstream msg;
msg << "[Primitive::vmap] Not implemented for ";
print(msg);
msg << name();
msg << ".";
throw std::invalid_argument(msg.str());
}
@ -211,7 +211,7 @@ std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
std::vector<Shape> Primitive::output_shapes(const std::vector<array>&) {
std::ostringstream msg;
msg << "[Primitive::output_shapes] ";
this->print(msg);
msg << name();
msg << " cannot infer output shapes.";
throw std::invalid_argument(msg.str());
}
@ -743,26 +743,6 @@ bool BitwiseBinary::is_equivalent(const Primitive& other) const {
return op_ == a_other.op_;
}
void BitwiseBinary::print(std::ostream& os) {
switch (op_) {
case BitwiseBinary::And:
os << "BitwiseAnd";
break;
case BitwiseBinary::Or:
os << "BitwiseOr";
break;
case BitwiseBinary::Xor:
os << "BitwiseXor";
break;
case BitwiseBinary::LeftShift:
os << "LeftShift";
break;
case BitwiseBinary::RightShift:
os << "RightShift";
break;
}
}
std::pair<std::vector<array>, std::vector<int>> BitwiseBinary::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
@ -5375,8 +5355,13 @@ std::pair<std::vector<array>, std::vector<int>> View::vmap(
return {{view(inputs[0], dtype_, stream())}, axes};
}
void View::print(std::ostream& os) {
os << "View " << dtype_;
const char* View::name() const {
if (name_.empty()) {
std::ostringstream os;
os << "View " << dtype_;
name_ = os.str();
}
return name_.c_str();
}
bool View::is_equivalent(const Primitive& other) const {

File diff suppressed because it is too large Load Diff

View File

@ -33,7 +33,7 @@ class Synchronizer : public Primitive {
void eval_cpu(const std::vector<array>&, std::vector<array>&) override {}
void eval_gpu(const std::vector<array>&, std::vector<array>&) override {}
DEFINE_PRINT(Synchronize);
DEFINE_NAME(Synchronize);
};
// Initialize the static tracing members from transforms_impl.h

View File

@ -514,7 +514,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc");
m.def(
"eigh",
[](const mx::array& a, const std::string UPLO, mx::StreamOrDevice s) {
[](const mx::array& a, const std::string& UPLO, mx::StreamOrDevice s) {
auto result = mx::linalg::eigh(a, UPLO, s);
return nb::make_tuple(result.first, result.second);
},

View File

@ -14,7 +14,7 @@ namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
bool DEPRECATE(const std::string& old_fn, const std::string new_fn) {
bool DEPRECATE(const char* old_fn, const char* new_fn) {
std::cerr << old_fn << " is deprecated and will be removed in a future "
<< "version. Use " << new_fn << " instead." << std::endl;
return true;

View File

@ -3076,7 +3076,7 @@ void init_ops(nb::module_& m) {
std::tuple<int>,
std::pair<int, int>,
std::vector<std::pair<int, int>>>& pad_width,
const std::string mode,
const std::string& mode,
const ScalarOrArray& constant_value,
mx::StreamOrDevice s) {
if (auto pv = std::get_if<int>(&pad_width); pv) {