mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
@@ -116,7 +116,7 @@ std::vector<array> Primitive::jvp(
|
||||
print(msg);
|
||||
msg << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
};
|
||||
}
|
||||
|
||||
std::vector<array> Primitive::vjp(
|
||||
const std::vector<array>&,
|
||||
@@ -128,7 +128,7 @@ std::vector<array> Primitive::vjp(
|
||||
print(msg);
|
||||
msg << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
|
||||
const std::vector<array>&,
|
||||
@@ -138,7 +138,7 @@ std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
|
||||
print(msg);
|
||||
msg << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int>> Primitive::output_shapes(
|
||||
const std::vector<array>&) {
|
||||
@@ -147,7 +147,7 @@ std::vector<std::vector<int>> Primitive::output_shapes(
|
||||
this->print(msg);
|
||||
msg << " cannot infer output shapes.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
};
|
||||
}
|
||||
|
||||
std::vector<array> Abs::vjp(
|
||||
const std::vector<array>& primals,
|
||||
@@ -3430,7 +3430,7 @@ std::pair<std::vector<array>, std::vector<int>> StopGradient::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
return {{stop_gradient(inputs[0], stream())}, axes};
|
||||
};
|
||||
}
|
||||
|
||||
std::vector<array> Subtract::vjp(
|
||||
const std::vector<array>& primals,
|
||||
|
||||
Reference in New Issue
Block a user