mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add Primitive::name and remove Primitive::print (#2365)
This commit is contained in:
@@ -107,7 +107,7 @@ Compiled::Compiled(
|
||||
// name and type of output
|
||||
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
|
||||
// computation performed
|
||||
a.primitive().print(os);
|
||||
os << a.primitive().name();
|
||||
// name of inputs to the function
|
||||
for (auto& inp : a.inputs()) {
|
||||
os << namer.get_name(inp);
|
||||
@@ -170,11 +170,16 @@ bool Compiled::is_equivalent(const Primitive& other) const {
|
||||
});
|
||||
}
|
||||
|
||||
void Compiled::print(std::ostream& os) {
|
||||
os << "Compiled";
|
||||
for (auto& a : tape_) {
|
||||
a.primitive().print(os);
|
||||
const char* Compiled::name() const {
|
||||
if (name_.empty()) {
|
||||
std::ostringstream os;
|
||||
os << "Compiled";
|
||||
for (auto& a : tape_) {
|
||||
os << a.primitive().name();
|
||||
}
|
||||
name_ = os.str();
|
||||
}
|
||||
return name_.c_str();
|
||||
}
|
||||
|
||||
std::vector<Shape> Compiled::output_shapes(const std::vector<array>& inputs) {
|
||||
|
||||
Reference in New Issue
Block a user