Implement vjps for some primitives in the fast namespace (#883)

* Implement rope vjp in terms of rope
* RMSNormVJP primitive and kernel
* Add LayerNormVJP primitive and kernel
This commit is contained in:
Angelos Katharopoulos
2024-03-26 16:35:34 -07:00
committed by GitHub
parent a789685c63
commit 29221fa238
14 changed files with 1383 additions and 110 deletions

View File

@@ -6,8 +6,6 @@
namespace mlx::core {
namespace {
enum ReductionOpType {
// Self-explanatory. Read everything and produce 1 output.
ContiguousAllReduce,
@@ -38,6 +36,21 @@ enum ReductionOpType {
GeneralReduce
};
struct ReductionPlan {
ReductionOpType type;
std::vector<int> shape;
std::vector<size_t> strides;
ReductionPlan(
ReductionOpType type_,
std::vector<int> shape_,
std::vector<size_t> strides_)
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
ReductionPlan(ReductionOpType type_) : type(type_) {}
};
namespace {
// Helper for the n-dimensional strided loop
// Should this be in utils?
inline void nd_loop(
@@ -110,19 +123,6 @@ struct DefaultContiguousReduce {
}
};
struct ReductionPlan {
ReductionOpType type;
std::vector<int> shape;
std::vector<size_t> strides;
ReductionPlan(
ReductionOpType type_,
std::vector<int> shape_,
std::vector<size_t> strides_)
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
ReductionPlan(ReductionOpType type_) : type(type_) {}
};
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&