Add Primitive::name and remove Primitive::print (#2365)

This commit is contained in:
Cheng
2025-07-15 06:06:35 +09:00
committed by GitHub
parent 5201df5030
commit d34f887abc
32 changed files with 307 additions and 340 deletions

View File

@@ -45,27 +45,22 @@ class AllReduce : public DistPrimitive {
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
void print(std::ostream& os) override {
const char* name() const override {
switch (reduce_type_) {
case And:
os << "And";
return "And AllReduce";
case Or:
os << "And";
break;
return "Or AllReduce";
case Sum:
os << "Sum";
break;
return "Sum AllReduce";
case Prod:
os << "Prod";
break;
return "Prod AllReduce";
case Min:
os << "Min";
break;
return "Min AllReduce";
case Max:
os << "Max";
break;
return "Max AllReduce";
}
os << " AllReduce";
return "<unknwon AllReduce>";
}
private:
@@ -94,7 +89,7 @@ class AllGather : public DistPrimitive {
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_PRINT(AllGather);
DEFINE_NAME(AllGather);
};
class Send : public DistPrimitive {
@@ -110,7 +105,7 @@ class Send : public DistPrimitive {
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
DEFINE_PRINT(Send);
DEFINE_NAME(Send);
private:
int dst_;
@@ -126,7 +121,7 @@ class Recv : public DistPrimitive {
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(Recv);
DEFINE_NAME(Recv);
private:
int src_;