diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst index 03f1c2163..5a4de8123 100644 --- a/docs/src/dev/extensions.rst +++ b/docs/src/dev/extensions.rst @@ -138,13 +138,13 @@ more concrete: * representing the vectorized computation and the axis which * corresponds to the output vectorized dimension. */ - virtual std::pair, std::vector> vmap( + std::pair, std::vector> vmap( const std::vector& inputs, const std::vector& axes) override; - /** Print the primitive. */ - void print(std::ostream& os) override { - os << "Axpby"; + /** The name of primitive. */ + const char* name() const override { + return "Axpby"; } /** Equivalence check **/ diff --git a/examples/extensions/axpby/axpby.h b/examples/extensions/axpby/axpby.h index 26f80961c..e6da491f8 100644 --- a/examples/extensions/axpby/axpby.h +++ b/examples/extensions/axpby/axpby.h @@ -74,9 +74,9 @@ class Axpby : public mx::Primitive { const std::vector& inputs, const std::vector& axes) override; - /** Print the primitive. */ - void print(std::ostream& os) override { - os << "Axpby"; + /** The name of primitive. */ + const char* name() const override { + return "Axpby"; } /** Equivalence check **/ diff --git a/mlx/backend/common/utils.cpp b/mlx/backend/common/utils.cpp index 942f9576e..ae169e35e 100644 --- a/mlx/backend/common/utils.cpp +++ b/mlx/backend/common/utils.cpp @@ -3,16 +3,9 @@ #include #include "mlx/backend/common/utils.h" -#include "mlx/primitives.h" namespace mlx::core { -std::string get_primitive_string(Primitive* primitive) { - std::ostringstream op_t; - primitive->print(op_t); - return op_t.str(); -} - std::filesystem::path current_binary_dir() { static std::filesystem::path binary_dir = []() { Dl_info info; diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index 543868e36..0f9846086 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -10,8 +10,6 @@ namespace mlx::core { -std::string get_primitive_string(Primitive* primitive); - // Return the directory that contains current shared library. 
std::filesystem::path current_binary_dir(); diff --git a/mlx/backend/cpu/compiled.cpp b/mlx/backend/cpu/compiled.cpp index d0bfb4f45..d85114987 100644 --- a/mlx/backend/cpu/compiled.cpp +++ b/mlx/backend/cpu/compiled.cpp @@ -231,7 +231,7 @@ inline void build_kernel( os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_" << namer.get_name(x.inputs()[0]) << ");" << std::endl; } else { - x.primitive().print(os); + os << x.primitive().name(); os << "()("; for (int i = 0; i < x.inputs().size() - 1; i++) { os << "tmp_" << namer.get_name(x.inputs()[i]) << ", "; diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index fc5b8c496..c8586e638 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -177,7 +177,7 @@ template void binary_op_gpu_inplace( const std::vector& inputs, array& out, - std::string_view op, + const char* op, const Stream& s) { assert(inputs.size() > 1); const auto& a = inputs[0]; @@ -291,7 +291,7 @@ template void binary_op_gpu( const std::vector& inputs, array& out, - std::string_view op, + const char* op, const Stream& s) { auto& a = inputs[0]; auto& b = inputs[1]; @@ -300,11 +300,11 @@ void binary_op_gpu( binary_op_gpu_inplace(inputs, out, op, s); } -#define BINARY_GPU(func) \ - void func::eval_gpu(const std::vector& inputs, array& out) { \ - nvtx3::scoped_range r(#func "::eval_gpu"); \ - auto& s = out.primitive().stream(); \ - binary_op_gpu(inputs, out, get_primitive_string(this), s); \ +#define BINARY_GPU(func) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + nvtx3::scoped_range r(#func "::eval_gpu"); \ + auto& s = out.primitive().stream(); \ + binary_op_gpu(inputs, out, name(), s); \ } BINARY_GPU(Add) @@ -328,33 +328,31 @@ BINARY_GPU(Subtract) void Equal::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("Equal::eval_gpu"); auto& s = out.primitive().stream(); - auto op = get_primitive_string(this); if (equal_nan_) { - binary_op_gpu(inputs, out, op, s); + binary_op_gpu(inputs, out, name(), s); } else { - binary_op_gpu(inputs, out, op, s); + binary_op_gpu(inputs, out, name(), s); } } void BitwiseBinary::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("BitwiseBinary::eval_gpu"); auto& s = out.primitive().stream(); - auto op = get_primitive_string(this); switch (op_) { case BitwiseBinary::And: - binary_op_gpu(inputs, out, op, s); + binary_op_gpu(inputs, out, name(), s); break; case BitwiseBinary::Or: - binary_op_gpu(inputs, out, op, s); + binary_op_gpu(inputs, out, name(), s); break; case BitwiseBinary::Xor: - binary_op_gpu(inputs, out, op, s); + binary_op_gpu(inputs, out, name(), s); break; case BitwiseBinary::LeftShift: - binary_op_gpu(inputs, out, op, s); + binary_op_gpu(inputs, out, name(), s); break; case BitwiseBinary::RightShift: - binary_op_gpu(inputs, out, op, s); + binary_op_gpu(inputs, out, name(), s); break; } } diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu index 4b6e24581..0918c579f 100644 --- a/mlx/backend/cuda/binary_two.cu +++ b/mlx/backend/cuda/binary_two.cu @@ -184,7 +184,7 @@ template void binary_two_op_gpu_inplace( const std::vector& inputs, std::vector& outputs, - std::string_view op, + const char* op, const Stream& s) { assert(inputs.size() > 1); const auto& a = inputs[0]; @@ -314,7 +314,7 @@ template void binary_two_op_gpu( const std::vector& inputs, std::vector& outputs, - std::string_view op, + const char* op, const Stream& s) { auto& a = inputs[0]; auto& b = inputs[1]; @@ -329,7 +329,7 @@ void DivMod::eval_gpu( 
std::vector& outputs) { nvtx3::scoped_range r("DivMod::eval_gpu"); auto& s = outputs[0].primitive().stream(); - binary_two_op_gpu(inputs, outputs, get_primitive_string(this), s); + binary_two_op_gpu(inputs, outputs, name(), s); } } // namespace mlx::core diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp index 21257e5dd..2f3990b90 100644 --- a/mlx/backend/cuda/compiled.cpp +++ b/mlx/backend/cuda/compiled.cpp @@ -106,9 +106,7 @@ struct FusedKernelBuilder { value = fmt::format( "static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0])); } else { - std::ostringstream ss; - x.primitive().print(ss); - value = ss.str(); + value = x.primitive().name(); value += "{}("; for (size_t i = 0; i < x.inputs().size() - 1; ++i) { value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i])); diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index 1fe1b557b..0d2754ef0 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -102,7 +102,7 @@ template void unary_op_gpu_inplace( const std::vector& inputs, array& out, - const std::string& op, + const char* op, const Stream& s) { auto& in = inputs[0]; if (in.size() == 0) { @@ -178,17 +178,17 @@ template void unary_op_gpu( const std::vector& inputs, array& out, - const std::string& op, + const char* op, const Stream& s) { set_unary_output_data(inputs[0], out); unary_op_gpu_inplace(inputs, out, op, s); } -#define UNARY_GPU(func) \ - void func::eval_gpu(const std::vector& inputs, array& out) { \ - nvtx3::scoped_range r(#func "::eval_gpu"); \ - auto& s = out.primitive().stream(); \ - unary_op_gpu(inputs, out, get_primitive_string(this), s); \ +#define UNARY_GPU(func) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + nvtx3::scoped_range r(#func "::eval_gpu"); \ + auto& s = out.primitive().stream(); \ + unary_op_gpu(inputs, out, name(), s); \ } UNARY_GPU(Abs) @@ -224,16 +224,15 @@ UNARY_GPU(Tanh) void Log::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("Log::eval_gpu"); auto& s = out.primitive().stream(); - auto op = get_primitive_string(this); switch (base_) { case Base::e: - unary_op_gpu(inputs, out, op, s); + unary_op_gpu(inputs, out, name(), s); break; case Base::two: - unary_op_gpu(inputs, out, op, s); + unary_op_gpu(inputs, out, name(), s); break; case Base::ten: - unary_op_gpu(inputs, out, op, s); + unary_op_gpu(inputs, out, name(), s); break; } } @@ -244,7 +243,7 @@ void Round::eval_gpu(const std::vector& inputs, array& out) { const auto& in = inputs[0]; auto& s = out.primitive().stream(); if (issubdtype(in.dtype(), inexact)) { - unary_op_gpu(inputs, out, get_primitive_string(this), s); + unary_op_gpu(inputs, out, name(), s); } else { // No-op integer types out.copy_shared_buffer(in); diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp index 54aaf153c..8c0e8c333 100644 --- a/mlx/backend/metal/binary.cpp +++ b/mlx/backend/metal/binary.cpp @@ -7,20 +7,20 @@ #define BINARY_GPU(func) \ void func::eval_gpu(const std::vector& inputs, array& out) { \ - binary_op_gpu(inputs, out, get_primitive_string(this)); \ + binary_op_gpu(inputs, out, name()); \ } #define BINARY_GPU_MULTI(func) \ void func::eval_gpu( \ const std::vector& inputs, std::vector& outputs) { \ - binary_op_gpu(inputs, outputs, get_primitive_string(this)); \ + binary_op_gpu(inputs, outputs, name()); \ } namespace mlx::core { std::string get_kernel_name( BinaryOpType bopt, - const std::string& op, + const char* op, const array& a, bool large, int ndim, @@ -65,7 +65,7 @@ 
std::string get_kernel_name( void binary_op_gpu_inplace( const std::vector& inputs, std::vector& outputs, - const std::string& op, + const char* op, const Stream& s) { auto& a = inputs[0]; auto& b = inputs[1]; @@ -165,7 +165,7 @@ void binary_op_gpu_inplace( void binary_op_gpu( const std::vector& inputs, std::vector& outputs, - const std::string& op, + const char* op, const Stream& s) { assert(inputs.size() == 2); auto& a = inputs[0]; @@ -179,7 +179,7 @@ void binary_op_gpu( void binary_op_gpu( const std::vector& inputs, std::vector& outputs, - const std::string& op) { + const char* op) { auto& s = outputs[0].primitive().stream(); binary_op_gpu(inputs, outputs, op, s); } @@ -187,7 +187,7 @@ void binary_op_gpu( void binary_op_gpu_inplace( const std::vector& inputs, array& out, - const std::string& op, + const char* op, const Stream& s) { std::vector outputs = {out}; binary_op_gpu_inplace(inputs, outputs, op, s); @@ -196,7 +196,7 @@ void binary_op_gpu_inplace( void binary_op_gpu( const std::vector& inputs, array& out, - const std::string& op, + const char* op, const Stream& s) { assert(inputs.size() == 2); auto& a = inputs[0]; @@ -209,7 +209,7 @@ void binary_op_gpu( void binary_op_gpu( const std::vector& inputs, array& out, - const std::string& op) { + const char* op) { auto& s = out.primitive().stream(); binary_op_gpu(inputs, out, op, s); } @@ -237,19 +237,19 @@ BINARY_GPU(Subtract) void BitwiseBinary::eval_gpu(const std::vector& inputs, array& out) { switch (op_) { case BitwiseBinary::And: - binary_op_gpu(inputs, out, get_primitive_string(this)); + binary_op_gpu(inputs, out, name()); break; case BitwiseBinary::Or: - binary_op_gpu(inputs, out, get_primitive_string(this)); + binary_op_gpu(inputs, out, name()); break; case BitwiseBinary::Xor: - binary_op_gpu(inputs, out, get_primitive_string(this)); + binary_op_gpu(inputs, out, name()); break; case BitwiseBinary::LeftShift: - binary_op_gpu(inputs, out, get_primitive_string(this)); + binary_op_gpu(inputs, out, name()); break; case BitwiseBinary::RightShift: - binary_op_gpu(inputs, out, get_primitive_string(this)); + binary_op_gpu(inputs, out, name()); break; } } diff --git a/mlx/backend/metal/binary.h b/mlx/backend/metal/binary.h index 8552c1e07..0341a2f83 100644 --- a/mlx/backend/metal/binary.h +++ b/mlx/backend/metal/binary.h @@ -9,25 +9,25 @@ namespace mlx::core { void binary_op_gpu( const std::vector& inputs, std::vector& outputs, - const std::string& op, + const char* op, const Stream& s); void binary_op_gpu( const std::vector& inputs, array& out, - const std::string& op, + const char* op, const Stream& s); void binary_op_gpu_inplace( const std::vector& inputs, std::vector& outputs, - const std::string& op, + const char* op, const Stream& s); void binary_op_gpu_inplace( const std::vector& inputs, array& out, - const std::string& op, + const char* op, const Stream& s); } // namespace mlx::core diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index 88edc6baa..eb51ab750 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -212,9 +212,7 @@ inline void build_kernel( get_type_string(x.dtype()), namer.get_name(x.inputs()[0])); } else { - std::ostringstream ss; - x.primitive().print(ss); - os += ss.str(); + os += x.primitive().name(); os += "()("; for (int i = 0; i < x.inputs().size() - 1; i++) { os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i])); diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index fd0e0db09..6ae72e0aa 100644 --- 
a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -8,12 +8,6 @@ using namespace fmt::literals; namespace mlx::core { -std::string op_name(const array& arr) { - std::ostringstream op_t; - arr.primitive().print(op_t); - return op_t.str(); -} - MTL::ComputePipelineState* get_arange_kernel( metal::Device& d, const std::string& kernel_name, @@ -33,7 +27,7 @@ MTL::ComputePipelineState* get_unary_kernel( const std::string& kernel_name, Dtype in_type, Dtype out_type, - const std::string op) { + const char* op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { auto in_t = get_type_string(in_type); @@ -58,10 +52,10 @@ MTL::ComputePipelineState* get_unary_kernel( } void append_binary_kernels( - const std::string lib_name, + const std::string& lib_name, Dtype in_type, Dtype out_type, - const std::string op, + const char* op, std::string& kernel_source) { const std::array, 7> kernel_types = {{ {"ss", "binary_ss"}, @@ -112,7 +106,7 @@ MTL::ComputePipelineState* get_binary_kernel( const std::string& kernel_name, Dtype in_type, Dtype out_type, - const std::string op) { + const char* op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { std::string kernel_source; @@ -129,7 +123,7 @@ MTL::ComputePipelineState* get_binary_two_kernel( const std::string& kernel_name, Dtype in_type, Dtype out_type, - const std::string op) { + const char* op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { std::string kernel_source = metal::utils(); @@ -144,7 +138,7 @@ MTL::ComputePipelineState* get_ternary_kernel( metal::Device& d, const std::string& kernel_name, Dtype type, - const std::string op) { + const char* op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { auto t_str = get_type_string(type); diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 794c67bdc..ca29ca52e 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -19,27 +19,27 @@ MTL::ComputePipelineState* get_unary_kernel( const std::string& kernel_name, Dtype in_type, Dtype out_type, - const std::string op); + const char* op); MTL::ComputePipelineState* get_binary_kernel( metal::Device& d, const std::string& kernel_name, Dtype in_type, Dtype out_type, - const std::string op); + const char* op); MTL::ComputePipelineState* get_binary_two_kernel( metal::Device& d, const std::string& kernel_name, Dtype in_type, Dtype out_type, - const std::string op); + const char* op); MTL::ComputePipelineState* get_ternary_kernel( metal::Device& d, const std::string& kernel_name, Dtype type, - const std::string op); + const char* op); MTL::ComputePipelineState* get_copy_kernel( metal::Device& d, @@ -257,8 +257,10 @@ MTL::ComputePipelineState* get_gather_qmm_kernel( // Create a GPU kernel template definition for JIT compilation template -std::string -get_template_definition(std::string name, std::string func, Args... args) { +std::string get_template_definition( + std::string_view name, + std::string_view func, + Args... 
args) { std::ostringstream s; s << func << "<"; bool first = true; diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index 32d3e75f7..a689a793e 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -18,7 +18,7 @@ MTL::ComputePipelineState* get_unary_kernel( const std::string& kernel_name, Dtype, Dtype, - const std::string) { + const char*) { return d.get_kernel(kernel_name); } @@ -27,7 +27,7 @@ MTL::ComputePipelineState* get_binary_kernel( const std::string& kernel_name, Dtype, Dtype, - const std::string) { + const char*) { return d.get_kernel(kernel_name); } @@ -36,7 +36,7 @@ MTL::ComputePipelineState* get_binary_two_kernel( const std::string& kernel_name, Dtype, Dtype, - const std::string) { + const char*) { return d.get_kernel(kernel_name); } @@ -44,7 +44,7 @@ MTL::ComputePipelineState* get_ternary_kernel( metal::Device& d, const std::string& kernel_name, Dtype, - const std::string) { + const char*) { return d.get_kernel(kernel_name); } diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index 22f2a1985..b2b9e3337 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -11,7 +11,7 @@ namespace mlx::core { void ternary_op_gpu_inplace( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s) { assert(inputs.size() == 3); auto& a = inputs[0]; @@ -128,7 +128,7 @@ void ternary_op_gpu_inplace( void ternary_op_gpu( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s) { auto& a = inputs[0]; auto& b = inputs[1]; @@ -141,13 +141,13 @@ void ternary_op_gpu( void ternary_op_gpu( const std::vector& inputs, array& out, - const std::string op) { + const char* op) { auto& s = out.primitive().stream(); ternary_op_gpu(inputs, out, op, s); } void Select::eval_gpu(const std::vector& inputs, array& out) { - ternary_op_gpu(inputs, out, get_primitive_string(this)); + ternary_op_gpu(inputs, out, name()); } } // namespace mlx::core diff --git a/mlx/backend/metal/ternary.h b/mlx/backend/metal/ternary.h index 0834140b8..91c6fbbeb 100644 --- a/mlx/backend/metal/ternary.h +++ b/mlx/backend/metal/ternary.h @@ -9,13 +9,13 @@ namespace mlx::core { void ternary_op_gpu( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s); void ternary_op_gpu_inplace( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s); } // namespace mlx::core diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index 0b118b72f..48f85635b 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -8,7 +8,7 @@ #define UNARY_GPU(func) \ void func::eval_gpu(const std::vector& inputs, array& out) { \ - unary_op_gpu(inputs, out, get_primitive_string(this)); \ + unary_op_gpu(inputs, out, name()); \ } namespace mlx::core { @@ -16,7 +16,7 @@ namespace mlx::core { void unary_op_gpu_inplace( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s) { auto& in = inputs[0]; bool contig = in.flags().contiguous; @@ -98,7 +98,7 @@ void unary_op_gpu_inplace( void unary_op_gpu( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s) { set_unary_output_data(inputs[0], out); unary_op_gpu_inplace(inputs, out, op, s); @@ -107,7 +107,7 @@ void unary_op_gpu( void unary_op_gpu( const std::vector& inputs, array& out, - const std::string op) { + const char* op) { 
auto& s = out.primitive().stream(); unary_op_gpu(inputs, out, op, s); } @@ -146,13 +146,13 @@ UNARY_GPU(Tanh) void Log::eval_gpu(const std::vector& inputs, array& out) { switch (base_) { case Base::e: - unary_op_gpu(inputs, out, get_primitive_string(this)); + unary_op_gpu(inputs, out, name()); break; case Base::two: - unary_op_gpu(inputs, out, get_primitive_string(this)); + unary_op_gpu(inputs, out, name()); break; case Base::ten: - unary_op_gpu(inputs, out, get_primitive_string(this)); + unary_op_gpu(inputs, out, name()); break; } } @@ -161,7 +161,7 @@ void Round::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (issubdtype(in.dtype(), inexact)) { - unary_op_gpu(inputs, out, get_primitive_string(this)); + unary_op_gpu(inputs, out, name()); } else { // No-op integer types out.copy_shared_buffer(in); diff --git a/mlx/backend/metal/unary.h b/mlx/backend/metal/unary.h index 19057076b..1d6ecf027 100644 --- a/mlx/backend/metal/unary.h +++ b/mlx/backend/metal/unary.h @@ -9,13 +9,13 @@ namespace mlx::core { void unary_op_gpu( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s); void unary_op_gpu_inplace( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s); } // namespace mlx::core diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index a491521a0..e7784e599 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -40,7 +40,7 @@ inline void debug_set_primitive_buffer_label( if (auto cbuf_label = command_buffer->label(); cbuf_label) { label << cbuf_label->utf8String(); } - primitive.print(label); + label << primitive.name(); command_buffer->setLabel(make_string(label)); #endif } diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 0cb3b5a85..91743ec04 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -107,7 +107,7 @@ Compiled::Compiled( // name and type of output os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize(); // computation performed - a.primitive().print(os); + os << a.primitive().name(); // name of inputs to the function for (auto& inp : a.inputs()) { os << namer.get_name(inp); @@ -170,11 +170,16 @@ bool Compiled::is_equivalent(const Primitive& other) const { }); } -void Compiled::print(std::ostream& os) { - os << "Compiled"; - for (auto& a : tape_) { - a.primitive().print(os); +const char* Compiled::name() const { + if (name_.empty()) { + std::ostringstream os; + os << "Compiled"; + for (auto& a : tape_) { + os << a.primitive().name(); + } + name_ = os.str(); } + return name_.c_str(); } std::vector Compiled::output_shapes(const std::vector& inputs) { diff --git a/mlx/distributed/primitives.h b/mlx/distributed/primitives.h index 7320e6cb6..7ad00a0d6 100644 --- a/mlx/distributed/primitives.h +++ b/mlx/distributed/primitives.h @@ -45,27 +45,22 @@ class AllReduce : public DistPrimitive { const std::vector& argnums, const std::vector& outputs) override; - void print(std::ostream& os) override { + const char* name() const override { switch (reduce_type_) { case And: - os << "And"; + return "And AllReduce"; case Or: - os << "And"; - break; + return "Or AllReduce"; case Sum: - os << "Sum"; - break; + return "Sum AllReduce"; case Prod: - os << "Prod"; - break; + return "Prod AllReduce"; case Min: - os << "Min"; - break; + return "Min AllReduce"; case Max: - os << "Max"; - break; + return "Max AllReduce"; } - os << " AllReduce"; + return ""; } private: @@ -94,7 +89,7 @@ class AllGather : public 
DistPrimitive { const std::vector& argnums, const std::vector& outputs) override; - DEFINE_PRINT(AllGather); + DEFINE_NAME(AllGather); }; class Send : public DistPrimitive { @@ -110,7 +105,7 @@ class Send : public DistPrimitive { const std::vector& inputs, const std::vector& axes) override; - DEFINE_PRINT(Send); + DEFINE_NAME(Send); private: int dst_; @@ -126,7 +121,7 @@ class Recv : public DistPrimitive { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_PRINT(Recv); + DEFINE_NAME(Recv); private: int src_; diff --git a/mlx/export.cpp b/mlx/export.cpp index 552c35cfb..8eb385bb1 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -354,9 +354,7 @@ struct PrimitiveFactory { void save(Writer& os, const std::shared_ptr& p) { serialize(os, p->stream()); - std::ostringstream pout; - p->print(pout); - auto name = pout.str(); + std::string name = p->name(); name = name.substr(0, name.find(' ')); if (auto it = name_remap.find(name); it != name_remap.end()) { name = it->second; diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 51050ea50..52135adad 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -58,7 +58,7 @@ class RMSNorm : public Custom { const std::vector& argnums, const std::vector& outputs) override; - DEFINE_PRINT(RMSNorm) + DEFINE_NAME(RMSNorm) bool is_equivalent(const Primitive& other) const override; DEFINE_INPUT_OUTPUT_SHAPE() @@ -85,7 +85,7 @@ class RMSNormVJP : public Custom { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_PRINT(RMSNormVJP) + DEFINE_NAME(RMSNormVJP) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_pair(nullptr, eps_); @@ -118,7 +118,7 @@ class LayerNorm : public Custom { const std::vector& argnums, const std::vector& outputs) override; - DEFINE_PRINT(LayerNorm) + DEFINE_NAME(LayerNorm) bool is_equivalent(const Primitive& other) const override; DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { @@ -144,7 +144,7 @@ class LayerNormVJP : public Custom { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_PRINT(LayerNormVJP) + DEFINE_NAME(LayerNormVJP) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_pair(nullptr, eps_); @@ -186,7 +186,7 @@ class RoPE : public Custom { const std::vector& argnums, const std::vector& outputs) override; - DEFINE_PRINT(RoPE) + DEFINE_NAME(RoPE) bool is_equivalent(const Primitive& other) const override; DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { @@ -233,7 +233,7 @@ class ScaledDotProductAttention : public Custom { void eval_gpu(const std::vector& inputs, array& out); bool is_equivalent(const Primitive& other) const override; - DEFINE_PRINT(ScaledDotProductAttention); + DEFINE_NAME(ScaledDotProductAttention); DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { return std::make_tuple(nullptr, scale_, do_causal_); @@ -263,7 +263,7 @@ class AffineQuantize : public Custom { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_PRINT(AffineQuantize); + DEFINE_NAME(AffineQuantize); bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; @@ -311,7 +311,7 @@ class CustomKernel : public Primitive { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_PRINT(CustomKernel); + DEFINE_NAME(CustomKernel); private: std::string source_; diff --git a/mlx/graph_utils.cpp b/mlx/graph_utils.cpp index 29373f266..854881bc9 
100644 --- a/mlx/graph_utils.cpp +++ b/mlx/graph_utils.cpp @@ -93,7 +93,7 @@ void print_graph( os << "\n"; for (auto& arr : tape) { - arr.primitive().print(os); + os << arr.primitive().name(); os << " "; print_arrs(arr.inputs()); os << " -> "; @@ -143,7 +143,7 @@ void export_to_dot( os << "{ "; os << x.primitive_id(); os << " [label =\""; - x.primitive().print(os); + os << x.primitive().name(); os << "\", shape=rectangle]"; os << "; }" << std::endl; // Arrows to primitive's inputs diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index ff3208e1e..e8a9e430e 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -500,7 +500,7 @@ array cross( void validate_eig( const array& a, const StreamOrDevice& stream, - const std::string fname) { + const std::string& fname) { check_cpu_stream(stream, fname); check_float_or_complex(a.dtype(), fname); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 72affbd34..cf0e6ef0d 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -181,7 +181,7 @@ std::vector Primitive::jvp( const std::vector&) { std::ostringstream msg; msg << "[Primitive::jvp] Not implemented for "; - print(msg); + msg << name(); msg << "."; throw std::invalid_argument(msg.str()); } @@ -193,7 +193,7 @@ std::vector Primitive::vjp( const std::vector&) { std::ostringstream msg; msg << "[Primitive::vjp] Not implemented for "; - print(msg); + msg << name(); msg << "."; throw std::invalid_argument(msg.str()); } @@ -203,7 +203,7 @@ std::pair, std::vector> Primitive::vmap( const std::vector&) { std::ostringstream msg; msg << "[Primitive::vmap] Not implemented for "; - print(msg); + msg << name(); msg << "."; throw std::invalid_argument(msg.str()); } @@ -211,7 +211,7 @@ std::pair, std::vector> Primitive::vmap( std::vector Primitive::output_shapes(const std::vector&) { std::ostringstream msg; msg << "[Primitive::output_shapes] "; - this->print(msg); + msg << name(); msg << " cannot infer output shapes."; throw std::invalid_argument(msg.str()); } @@ -743,26 +743,6 @@ bool BitwiseBinary::is_equivalent(const Primitive& other) const { return op_ == a_other.op_; } -void BitwiseBinary::print(std::ostream& os) { - switch (op_) { - case BitwiseBinary::And: - os << "BitwiseAnd"; - break; - case BitwiseBinary::Or: - os << "BitwiseOr"; - break; - case BitwiseBinary::Xor: - os << "BitwiseXor"; - break; - case BitwiseBinary::LeftShift: - os << "LeftShift"; - break; - case BitwiseBinary::RightShift: - os << "RightShift"; - break; - } -} - std::pair, std::vector> BitwiseBinary::vmap( const std::vector& inputs, const std::vector& axes) { @@ -5375,8 +5355,13 @@ std::pair, std::vector> View::vmap( return {{view(inputs[0], dtype_, stream())}, axes}; } -void View::print(std::ostream& os) { - os << "View " << dtype_; +const char* View::name() const { + if (name_.empty()) { + std::ostringstream os; + os << "View " << dtype_; + name_ = os.str(); + } + return name_.c_str(); } bool View::is_equivalent(const Primitive& other) const { diff --git a/mlx/primitives.h b/mlx/primitives.h index 3d3202aaa..d482a1bf9 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -26,9 +26,9 @@ const std::vector& argnums, \ const std::vector& outputs) override; -#define DEFINE_PRINT(PRIMITIVE) \ - void print(std::ostream& os) override { \ - os << #PRIMITIVE; \ +#define DEFINE_NAME(PRIMITIVE) \ + const char* name() const override { \ + return #PRIMITIVE; \ } #define DEFINE_DEFAULT_IS_EQUIVALENT() \ @@ -100,8 +100,8 @@ class Primitive { const std::vector& inputs, const std::vector& axes); - /** Print the primitive. 
*/ - virtual void print(std::ostream& os) = 0; + /** Get the name of primitive. */ + virtual const char* name() const = 0; /** Equivalence check defaults to false unless overridden by the primitive */ virtual bool is_equivalent(const Primitive& other) const { @@ -160,7 +160,7 @@ class Abs : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Abs) + DEFINE_NAME(Abs) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -174,7 +174,7 @@ class Add : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Add) + DEFINE_NAME(Add) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -189,7 +189,7 @@ class AddMM : public UnaryPrimitive { DEFINE_GRADS() DEFINE_VMAP() - DEFINE_PRINT(AddMM) + DEFINE_NAME(AddMM) bool is_equivalent(const Primitive& other) const override; std::pair state() const { @@ -209,7 +209,7 @@ class Arange : public UnaryPrimitive { void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - DEFINE_PRINT(Arange) + DEFINE_NAME(Arange) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; std::tuple state() const { @@ -231,7 +231,7 @@ class ArcCos : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArcCos) + DEFINE_NAME(ArcCos) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -245,7 +245,7 @@ class ArcCosh : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArcCosh) + DEFINE_NAME(ArcCosh) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -259,7 +259,7 @@ class ArcSin : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArcSin) + DEFINE_NAME(ArcSin) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -273,7 +273,7 @@ class ArcSinh : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArcSinh) + DEFINE_NAME(ArcSinh) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -287,7 +287,7 @@ class ArcTan : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArcTan) + DEFINE_NAME(ArcTan) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -301,7 +301,7 @@ class ArcTan2 : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArcTan2) + DEFINE_NAME(ArcTan2) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -315,7 +315,7 @@ class ArcTanh : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArcTanh) + DEFINE_NAME(ArcTanh) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -330,7 +330,7 @@ class ArgPartition : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArgPartition) + DEFINE_NAME(ArgPartition) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; std::pair state() const { @@ -357,7 +357,7 @@ class ArgReduce : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArgReduce) + DEFINE_NAME(ArgReduce) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; std::pair state() const { @@ -379,7 +379,7 @@ class ArgSort : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArgSort) + DEFINE_NAME(ArgSort) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; int state() const { @@ -400,7 +400,7 @@ class AsType : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(AsType) + DEFINE_NAME(AsType) 
DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; Dtype state() const { @@ -423,7 +423,7 @@ class AsStrided : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_GRADS() - DEFINE_PRINT(AsStrided) + DEFINE_NAME(AsStrided) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple(shape_, strides_, offset_); @@ -449,8 +449,24 @@ class BitwiseBinary : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() + + const char* name() const override { + switch (op_) { + case BitwiseBinary::And: + return "BitwiseAnd"; + case BitwiseBinary::Or: + return "BitwiseOr"; + case BitwiseBinary::Xor: + return "BitwiseXor"; + case BitwiseBinary::LeftShift: + return "LeftShift"; + case BitwiseBinary::RightShift: + return "RightShift"; + } + return ""; + } + bool is_equivalent(const Primitive& other) const override; - void print(std::ostream& os) override; DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { return op_; @@ -468,7 +484,7 @@ class BitwiseInvert : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_VMAP() - DEFINE_PRINT(BitwiseInvert) + DEFINE_NAME(BitwiseInvert) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -487,7 +503,7 @@ class BlockMaskedMM : public UnaryPrimitive { const std::vector& argnums, const std::vector& outputs) override; - DEFINE_PRINT(BlockMaskedMM) + DEFINE_NAME(BlockMaskedMM) bool is_equivalent(const Primitive& other) const override; auto state() const { return block_size_; @@ -516,7 +532,7 @@ class GatherMM : public UnaryPrimitive { const std::vector& argnums, const std::vector& outputs) override; - DEFINE_PRINT(GatherMM) + DEFINE_NAME(GatherMM) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_pair(left_sorted_, right_sorted_); @@ -534,7 +550,7 @@ class SegmentedMM : public UnaryPrimitive { void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - DEFINE_PRINT(SegmentedMM) + DEFINE_NAME(SegmentedMM) }; class BroadcastAxes : public UnaryPrimitive { @@ -547,7 +563,7 @@ class BroadcastAxes : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(BroadcastAxes) + DEFINE_NAME(BroadcastAxes) bool is_equivalent(const Primitive& other) const override; static Shape output_shape( const std::vector& inputs, @@ -572,7 +588,7 @@ class Broadcast : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Broadcast) + DEFINE_NAME(Broadcast) static Shape output_shape(const std::vector& inputs); std::vector output_shapes(const std::vector& inputs) override; bool is_equivalent(const Primitive& other) const override; @@ -595,7 +611,7 @@ class Ceil : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Ceil) + DEFINE_NAME(Ceil) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -625,8 +641,8 @@ class Compiled : public Primitive { DEFINE_VMAP() DEFINE_GRADS() + const char* name() const override; std::vector output_shapes(const std::vector& inputs) override; - void print(std::ostream& os) override; bool is_equivalent(const Primitive& other) const override; std::string lib_name() const { @@ -640,6 +656,7 @@ class Compiled : public Primitive { const std::unordered_set constant_ids_; const std::function is_constant_; + mutable std::string name_; std::string kernel_lib_; }; @@ -653,7 +670,7 @@ class Concatenate : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() 
- DEFINE_PRINT(Concatenate) + DEFINE_NAME(Concatenate) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; auto state() const { @@ -672,7 +689,7 @@ class Conjugate : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_VMAP() - DEFINE_PRINT(Conjugate) + DEFINE_NAME(Conjugate) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -687,7 +704,7 @@ class Contiguous : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Contiguous) + DEFINE_NAME(Contiguous) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; @@ -726,7 +743,7 @@ class Convolution : public UnaryPrimitive { const std::vector& outputs) override; DEFINE_VMAP() - DEFINE_PRINT(Convolution) + DEFINE_NAME(Convolution) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple( @@ -758,7 +775,7 @@ class Copy : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Copy) + DEFINE_NAME(Copy) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() @@ -775,7 +792,7 @@ class Cos : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Cos) + DEFINE_NAME(Cos) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -789,7 +806,7 @@ class Cosh : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Cosh) + DEFINE_NAME(Cosh) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -823,7 +840,7 @@ class CustomTransforms : public Primitive { DEFINE_GRADS(); DEFINE_VMAP(); - DEFINE_PRINT(CustomTransforms); + DEFINE_NAME(CustomTransforms); private: void eval(const std::vector& inputs, std::vector& outputs); @@ -861,7 +878,7 @@ class Depends : public Primitive { const std::vector& argnums, const std::vector& outputs) override; - DEFINE_PRINT(Depends); + DEFINE_NAME(Depends); private: void eval(const std::vector& inputs, std::vector& outputs); @@ -876,7 +893,7 @@ class Divide : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Divide) + DEFINE_NAME(Divide) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -892,7 +909,7 @@ class DivMod : public Primitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(DivMod) + DEFINE_NAME(DivMod) DEFINE_DEFAULT_IS_EQUIVALENT() std::vector output_shapes(const std::vector& inputs) override { return std::vector{inputs[0].shape(), inputs[0].shape()}; @@ -908,7 +925,7 @@ class Select : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Select) + DEFINE_NAME(Select) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -922,7 +939,7 @@ class Remainder : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Remainder) + DEFINE_NAME(Remainder) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -940,11 +957,11 @@ class Equal : public UnaryPrimitive { DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() - void print(std::ostream& os) override { + const char* name() const override { if (equal_nan_) { - os << "NaNEqual"; + return "NaNEqual"; } else { - os << "Equal"; + return "Equal"; } } auto state() const { @@ -964,7 +981,7 @@ class Erf : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Erf) + DEFINE_NAME(Erf) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -978,7 +995,7 @@ class ErfInv : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ErfInv) + DEFINE_NAME(ErfInv) DEFINE_DEFAULT_IS_EQUIVALENT() 
DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -992,7 +1009,7 @@ class Exp : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Exp) + DEFINE_NAME(Exp) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1006,7 +1023,7 @@ class Expm1 : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Expm1) + DEFINE_NAME(Expm1) DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1020,7 +1037,7 @@ class ExpandDims : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ExpandDims) + DEFINE_NAME(ExpandDims) std::vector output_shapes(const std::vector& inputs) override; bool is_equivalent(const Primitive& other) const override; @@ -1049,7 +1066,7 @@ class FFT : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(FFT) + DEFINE_NAME(FFT) bool is_equivalent(const Primitive& other) const override; auto state() const { @@ -1072,7 +1089,7 @@ class Flatten : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Flatten) + DEFINE_NAME(Flatten) std::vector output_shapes(const std::vector& inputs) override; bool is_equivalent(const Primitive& other) const override; @@ -1096,7 +1113,7 @@ class Floor : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Floor) + DEFINE_NAME(Floor) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1110,7 +1127,7 @@ class Full : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Full) + DEFINE_NAME(Full) DEFINE_DEFAULT_IS_EQUIVALENT() }; @@ -1126,7 +1143,7 @@ class Gather : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Gather) + DEFINE_NAME(Gather) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; std::pair, std::vector> state() const { @@ -1148,7 +1165,7 @@ class GatherAxis : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(GatherAxis) + DEFINE_NAME(GatherAxis) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; auto state() const { @@ -1168,7 +1185,7 @@ class Greater : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Greater) + DEFINE_NAME(Greater) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1182,7 +1199,7 @@ class GreaterEqual : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(GreaterEqual) + DEFINE_NAME(GreaterEqual) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1197,7 +1214,7 @@ class Hadamard : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Hadamard) + DEFINE_NAME(Hadamard) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; @@ -1218,7 +1235,7 @@ class Imag : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Imag) + DEFINE_NAME(Imag) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1232,7 +1249,7 @@ class Less : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Less) + DEFINE_NAME(Less) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1246,7 +1263,7 @@ class LessEqual : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(LessEqual) + DEFINE_NAME(LessEqual) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1266,7 +1283,7 @@ class Load : public UnaryPrimitive { void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - DEFINE_PRINT(Load) + DEFINE_NAME(Load) private: std::shared_ptr reader_; @@ 
-1293,18 +1310,16 @@ class Log : public UnaryPrimitive { return base_; }; - void print(std::ostream& os) override { + const char* name() const override { switch (base_) { case e: - os << "Log"; - break; + return "Log"; case two: - os << "Log2"; - break; + return "Log2"; case ten: - os << "Log10"; - break; + return "Log10"; } + return ""; } private: @@ -1320,7 +1335,7 @@ class Log1p : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Log1p) + DEFINE_NAME(Log1p) DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1333,7 +1348,7 @@ class LogicalNot : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(LogicalNot) + DEFINE_NAME(LogicalNot) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1347,7 +1362,7 @@ class LogicalAnd : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(LogicalAnd) + DEFINE_NAME(LogicalAnd) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1361,7 +1376,7 @@ class LogicalOr : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(LogicalOr) + DEFINE_NAME(LogicalOr) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1375,7 +1390,7 @@ class LogAddExp : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(LogAddExp) + DEFINE_NAME(LogAddExp) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1389,7 +1404,7 @@ class LogSumExp : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(LogSumExp) + DEFINE_NAME(LogSumExp) DEFINE_DEFAULT_IS_EQUIVALENT() std::vector output_shapes(const std::vector& inputs) override; }; @@ -1403,7 +1418,7 @@ class Matmul : public UnaryPrimitive { DEFINE_GRADS() DEFINE_VMAP() - DEFINE_PRINT(Matmul) + DEFINE_NAME(Matmul) DEFINE_DEFAULT_IS_EQUIVALENT() std::vector output_shapes(const std::vector& inputs) override; }; @@ -1417,7 +1432,7 @@ class Maximum : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Maximum) + DEFINE_NAME(Maximum) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1431,7 +1446,7 @@ class Minimum : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Minimum) + DEFINE_NAME(Minimum) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1445,7 +1460,7 @@ class Multiply : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Multiply) + DEFINE_NAME(Multiply) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1459,7 +1474,7 @@ class Negative : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Negative) + DEFINE_NAME(Negative) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1473,7 +1488,7 @@ class NotEqual : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(NotEqual) + DEFINE_NAME(NotEqual) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1494,7 +1509,7 @@ class NumberOfElements : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_VMAP() - DEFINE_PRINT(NumberOfElements) + DEFINE_NAME(NumberOfElements) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override { return {{}}; @@ -1528,7 +1543,7 @@ class Pad : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Pad) + DEFINE_NAME(Pad) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple(axes_, low_pad_size_, high_pad_size_); @@ -1550,7 +1565,7 @@ class Partition : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - 
DEFINE_PRINT(Partition) + DEFINE_NAME(Partition) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; auto state() const { @@ -1571,7 +1586,7 @@ class Power : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Power) + DEFINE_NAME(Power) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1593,7 +1608,7 @@ class QuantizedMatmul : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(QuantizedMatmul) + DEFINE_NAME(QuantizedMatmul) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; auto state() const { @@ -1627,7 +1642,7 @@ class GatherQMM : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(GatherQMM) + DEFINE_NAME(GatherQMM) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple( @@ -1651,7 +1666,7 @@ class RandomBits : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_VMAP() - DEFINE_PRINT(RandomBits) + DEFINE_NAME(RandomBits) bool is_equivalent(const Primitive& other) const override; std::pair, int> state() const { return {shape_, width_}; @@ -1671,7 +1686,7 @@ class Real : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Real) + DEFINE_NAME(Real) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1686,7 +1701,7 @@ class Reshape : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Reshape) + DEFINE_NAME(Reshape) bool is_equivalent(const Primitive& other) const override; std::vector state() const { return shape_; @@ -1721,28 +1736,24 @@ class Reduce : public UnaryPrimitive { std::vector output_shapes(const std::vector& inputs) override; - void print(std::ostream& os) override { + const char* name() const override { switch (reduce_type_) { case And: - os << "And"; - break; + return "And"; case Or: - os << "Or"; - break; + return "Or"; case Sum: - os << "Sum"; - break; + return "Sum"; case Prod: - os << "Prod"; - break; + return "Prod"; case Min: - os << "Min"; - break; + return "Min"; case Max: - os << "Max"; - break; + return "Max"; } + return ""; } + bool is_equivalent(const Primitive& other) const override; std::pair> state() const { return {reduce_type_, axes_}; @@ -1762,7 +1773,7 @@ class Round : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Round) + DEFINE_NAME(Round) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1789,26 +1800,22 @@ class Scan : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS(); - void print(std::ostream& os) override { - os << "Cum"; + const char* name() const override { switch (reduce_type_) { case Sum: - os << "Sum"; - break; + return "CumSum"; case Prod: - os << "Prod"; - break; + return "CumProd"; case Min: - os << "Min"; - break; + return "CumMin"; case Max: - os << "Max"; - break; + return "CumMax"; case LogAddExp: - os << "Logaddexp"; - break; + return "CumLogAddExp"; } + return ""; } + bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple(reduce_type_, axis_, reverse_, inclusive_); @@ -1837,25 +1844,22 @@ class Scatter : public UnaryPrimitive { DEFINE_VMAP(); DEFINE_GRADS(); - void print(std::ostream& os) override { - os << "Scatter"; + const char* name() const override { switch (reduce_type_) { case Sum: - os << " Sum"; - break; + return "ScatterSum"; case Prod: - os << " Prod"; - break; + return "ScatterProd"; case Min: - os << " Min"; - break; + return 
"ScatterMin"; case Max: - os << " Max"; - break; + return "ScatterMax"; case None: - break; + return "Scatter"; } + return ""; } + bool is_equivalent(const Primitive& other) const override; std::pair> state() const { return {reduce_type_, axes_}; @@ -1879,15 +1883,14 @@ class ScatterAxis : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - void print(std::ostream& os) override { - os << "ScatterAxis"; + const char* name() const override { switch (reduce_type_) { case Sum: - os << " Sum"; - break; + return "ScatterAxisSum"; case None: - break; + return "ScatterAxis"; } + return ""; } bool is_equivalent(const Primitive& other) const override; @@ -1910,7 +1913,7 @@ class Sigmoid : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Sigmoid) + DEFINE_NAME(Sigmoid) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1924,7 +1927,7 @@ class Sign : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Sign) + DEFINE_NAME(Sign) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1938,7 +1941,7 @@ class Sin : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Sin) + DEFINE_NAME(Sin) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1952,7 +1955,7 @@ class Sinh : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Sinh) + DEFINE_NAME(Sinh) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1974,7 +1977,7 @@ class Slice : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Slice) + DEFINE_NAME(Slice) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple(start_indices_, end_indices_, strides_); @@ -2003,7 +2006,7 @@ class SliceUpdate : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(SliceUpdate) + DEFINE_NAME(SliceUpdate) bool is_equivalent(const Primitive& other) const override; DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { @@ -2028,7 +2031,7 @@ class DynamicSlice : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(DynamicSlice) + DEFINE_NAME(DynamicSlice) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; auto state() const { @@ -2050,7 +2053,7 @@ class DynamicSliceUpdate : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(DynamicSliceUpdate) + DEFINE_NAME(DynamicSliceUpdate) bool is_equivalent(const Primitive& other) const override; DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { @@ -2071,7 +2074,7 @@ class Softmax : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Softmax) + DEFINE_NAME(Softmax) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; @@ -2093,7 +2096,7 @@ class Sort : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Sort) + DEFINE_NAME(Sort) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; auto state() const { @@ -2116,7 +2119,7 @@ class Split : public Primitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Split) + DEFINE_NAME(Split) bool is_equivalent(const Primitive& other) const override; std::pair, int> state() const { return {indices_, axis_}; @@ -2138,7 +2141,7 @@ class Square : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Square) + DEFINE_NAME(Square) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -2159,11 +2162,11 @@ class Sqrt : public UnaryPrimitive { return recip_; } - void print(std::ostream& os) override { 
+ const char* name() const override { if (recip_) { - os << "Rsqrt"; + return "Rsqrt"; } else { - os << "Sqrt"; + return "Sqrt"; } } @@ -2179,7 +2182,7 @@ class StopGradient : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_VMAP() - DEFINE_PRINT(StopGradient) + DEFINE_NAME(StopGradient) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() @@ -2196,7 +2199,7 @@ class Subtract : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Subtract) + DEFINE_NAME(Subtract) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -2211,7 +2214,7 @@ class Squeeze : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Squeeze) + DEFINE_NAME(Squeeze) std::vector output_shapes(const std::vector& inputs) override; bool is_equivalent(const Primitive& other) const override; @@ -2235,7 +2238,7 @@ class Tan : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Tan) + DEFINE_NAME(Tan) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -2249,7 +2252,7 @@ class Tanh : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Tanh) + DEFINE_NAME(Tanh) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -2264,7 +2267,7 @@ class Unflatten : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Unflatten) + DEFINE_NAME(Unflatten) std::vector output_shapes(const std::vector& inputs) override; bool is_equivalent(const Primitive& other) const override; @@ -2288,7 +2291,7 @@ class View : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_VMAP() - void print(std::ostream& os) override; + const char* name() const override; bool is_equivalent(const Primitive& other) const override; auto state() const { return dtype_; @@ -2296,6 +2299,7 @@ class View : public UnaryPrimitive { private: Dtype dtype_; + mutable std::string name_; }; class Transpose : public UnaryPrimitive { @@ -2308,7 +2312,7 @@ class Transpose : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Transpose) + DEFINE_NAME(Transpose) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; std::vector state() const { @@ -2331,7 +2335,7 @@ class QRF : public Primitive { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_PRINT(QRF) + DEFINE_NAME(QRF) }; /* SVD primitive. 
*/ @@ -2346,7 +2350,7 @@ class SVD : public Primitive { override; DEFINE_VMAP() - DEFINE_PRINT(SVD) + DEFINE_NAME(SVD) auto state() const { return compute_uv_; } @@ -2365,7 +2369,7 @@ class Inverse : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& output) override; DEFINE_VMAP() - DEFINE_PRINT(Inverse) + DEFINE_NAME(Inverse) auto state() const { return std::make_pair(tri_, upper_); } @@ -2387,7 +2391,7 @@ class Cholesky : public UnaryPrimitive { } DEFINE_VMAP() - DEFINE_PRINT(Cholesky) + DEFINE_NAME(Cholesky) private: bool upper_; @@ -2403,7 +2407,7 @@ class Eig : public Primitive { override; DEFINE_VMAP() - DEFINE_PRINT(Eig) + DEFINE_NAME(Eig) std::vector output_shapes(const std::vector& inputs) override; @@ -2428,7 +2432,7 @@ class Eigh : public Primitive { override; DEFINE_VMAP() - DEFINE_PRINT(Eigh) + DEFINE_NAME(Eigh) std::vector output_shapes(const std::vector& inputs) override; @@ -2451,7 +2455,7 @@ class LUF : public Primitive { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_PRINT(LUF) + DEFINE_NAME(LUF) }; } // namespace mlx::core diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 2d9942eda..d9e227ea3 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -33,7 +33,7 @@ class Synchronizer : public Primitive { void eval_cpu(const std::vector&, std::vector&) override {} void eval_gpu(const std::vector&, std::vector&) override {} - DEFINE_PRINT(Synchronize); + DEFINE_NAME(Synchronize); }; // Initialize the static tracing members from transforms_impl.h diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index cc8e79db6..634abaef4 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -514,7 +514,7 @@ void init_linalg(nb::module_& parent_module) { )pbdoc"); m.def( "eigh", - [](const mx::array& a, const std::string UPLO, mx::StreamOrDevice s) { + [](const mx::array& a, const std::string& UPLO, mx::StreamOrDevice s) { auto result = mx::linalg::eigh(a, UPLO, s); return nb::make_tuple(result.first, result.second); }, diff --git a/python/src/metal.cpp b/python/src/metal.cpp index 54642409c..3b2f4a53a 100644 --- a/python/src/metal.cpp +++ b/python/src/metal.cpp @@ -14,7 +14,7 @@ namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; -bool DEPRECATE(const std::string& old_fn, const std::string new_fn) { +bool DEPRECATE(const char* old_fn, const char* new_fn) { std::cerr << old_fn << " is deprecated and will be removed in a future " << "version. Use " << new_fn << " instead." << std::endl; return true; diff --git a/python/src/ops.cpp b/python/src/ops.cpp index d047f64cb..9703bbd2d 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3076,7 +3076,7 @@ void init_ops(nb::module_& m) { std::tuple, std::pair, std::vector>>& pad_width, - const std::string mode, + const std::string& mode, const ScalarOrArray& constant_value, mx::StreamOrDevice s) { if (auto pv = std::get_if(&pad_width); pv) {
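The patch above replaces the virtual `Primitive::print(std::ostream&)` with `Primitive::name()` returning `const char*`, renames `DEFINE_PRINT` to `DEFINE_NAME`, drops the `get_primitive_string` helper, and narrows the GPU op-name parameters from `std::string`/`std::string_view` to `const char*`. The sketch below is a minimal, self-contained illustration of the two patterns the patch relies on; it is not MLX code. `PrimitiveBase`, `ViewLike`, and the `int dtype_` member are stand-ins introduced only for this example, while the fixed-name `Axpby` case and the mutable-cache idiom mirror `DEFINE_NAME(...)`, `View::name()`, and `Compiled::name()` from the diff.

#include <iostream>
#include <sstream>
#include <string>

// Stand-in for mlx::core::Primitive, reduced to the one virtual this patch changes.
struct PrimitiveBase {
  virtual ~PrimitiveBase() = default;
  virtual const char* name() const = 0;
};

// Fixed name: roughly what DEFINE_NAME(Axpby) expands to.
struct Axpby : PrimitiveBase {
  const char* name() const override {
    return "Axpby";
  }
};

// Parameterized name: build it lazily and cache it in a mutable member so
// name() stays const and the returned pointer remains valid for the
// primitive's lifetime (the View::name() / Compiled::name() pattern above).
struct ViewLike : PrimitiveBase {
  explicit ViewLike(int dtype) : dtype_(dtype) {}
  const char* name() const override {
    if (name_.empty()) {
      std::ostringstream os;
      os << "View " << dtype_;
      name_ = os.str();
    }
    return name_.c_str();
  }

 private:
  int dtype_;
  mutable std::string name_;
};

int main() {
  Axpby axpby;
  ViewLike view(7);
  // Call sites that previously did `p.print(os)` or built a string via
  // get_primitive_string(&p) now stream or pass the name directly.
  std::cout << axpby.name() << "\n" << view.name() << "\n";
  return 0;
}

One likely motivation for `const char*` in the unary/binary/ternary GPU helpers is that it avoids constructing a `std::string` at every `eval_gpu` call site: the pointers come either from string literals (the `DEFINE_NAME` case) or from a string cached in the primitive itself, so they outlive the kernel launch they name.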