mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add Primitive::name and remove Primitive::print (#2365)
This commit is contained in:
@@ -45,27 +45,22 @@ class AllReduce : public DistPrimitive {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
void print(std::ostream& os) override {
|
||||
const char* name() const override {
|
||||
switch (reduce_type_) {
|
||||
case And:
|
||||
os << "And";
|
||||
return "And AllReduce";
|
||||
case Or:
|
||||
os << "And";
|
||||
break;
|
||||
return "Or AllReduce";
|
||||
case Sum:
|
||||
os << "Sum";
|
||||
break;
|
||||
return "Sum AllReduce";
|
||||
case Prod:
|
||||
os << "Prod";
|
||||
break;
|
||||
return "Prod AllReduce";
|
||||
case Min:
|
||||
os << "Min";
|
||||
break;
|
||||
return "Min AllReduce";
|
||||
case Max:
|
||||
os << "Max";
|
||||
break;
|
||||
return "Max AllReduce";
|
||||
}
|
||||
os << " AllReduce";
|
||||
return "<unknwon AllReduce>";
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -94,7 +89,7 @@ class AllGather : public DistPrimitive {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
DEFINE_PRINT(AllGather);
|
||||
DEFINE_NAME(AllGather);
|
||||
};
|
||||
|
||||
class Send : public DistPrimitive {
|
||||
@@ -110,7 +105,7 @@ class Send : public DistPrimitive {
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
DEFINE_PRINT(Send);
|
||||
DEFINE_NAME(Send);
|
||||
|
||||
private:
|
||||
int dst_;
|
||||
@@ -126,7 +121,7 @@ class Recv : public DistPrimitive {
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_PRINT(Recv);
|
||||
DEFINE_NAME(Recv);
|
||||
|
||||
private:
|
||||
int src_;
|
||||
|
||||
Reference in New Issue
Block a user