More primitives for compiling with shapeless (#1653)

* more shapeless and more Shape

* more shape

* fix

* fix
This commit is contained in:
Awni Hannun
2024-12-06 11:29:18 -08:00
committed by GitHub
parent 95c4a2e3af
commit d0b6cb0425
5 changed files with 160 additions and 81 deletions

View File

@@ -267,6 +267,11 @@ bool Arange::is_equivalent(const Primitive& other) const {
step_ == a_other.step_);
}
std::vector<Shape> Arange::output_shapes(const std::vector<array>&) {
auto real_size = std::ceil((stop_ - start_) / step_);
return {{std::max(static_cast<int>(real_size), 0)}};
}
std::vector<array> ArcCos::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
@@ -534,11 +539,10 @@ std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
return {{argsort(inputs[0], axis_ + axis_left, stream())}, axes};
}
std::vector<std::vector<int>> ArgReduce::output_shapes(
const std::vector<array>& inputs) {
std::vector<Shape> ArgReduce::output_shapes(const std::vector<array>& inputs) {
auto out_shape = inputs[0].shape();
out_shape[axis_] = 1;
return {out_shape};
return {std::move(out_shape)};
}
bool ArgSort::is_equivalent(const Primitive& other) const {
@@ -787,6 +791,23 @@ std::pair<std::vector<array>, std::vector<int>> Eigh::vmap(
return {outputs, std::vector<int>(outputs.size(), ax)};
}
std::vector<Shape> Eigh::output_shapes(const std::vector<array>& inputs) {
auto shape = inputs[0].shape();
shape.pop_back(); // Remove last dimension for eigenvalues
if (compute_eigenvectors_) {
return {
std::move(shape), inputs[0].shape()}; // Eigenvalues and eigenvectors
} else {
return {std::move(shape)}; // Only eigenvalues
}
}
bool Eigh::is_equivalent(const Primitive& other) const {
auto& e_other = static_cast<const Eigh&>(other);
return uplo_ == e_other.uplo_ &&
compute_eigenvectors_ == e_other.compute_eigenvectors_;
}
std::vector<array> Concatenate::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
@@ -881,6 +902,15 @@ bool Concatenate::is_equivalent(const Primitive& other) const {
return axis_ == c_other.axis_;
}
std::vector<Shape> Concatenate::output_shapes(
const std::vector<array>& inputs) {
auto shape = inputs[0].shape();
for (int i = 1; i < inputs.size(); ++i) {
shape[axis_] += inputs[i].shape(axis_);
}
return {std::move(shape)};
}
std::pair<std::vector<array>, std::vector<int>> Conjugate::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
@@ -1811,6 +1841,15 @@ bool Gather::is_equivalent(const Primitive& other) const {
return axes_ == g_other.axes_ && slice_sizes_ == g_other.slice_sizes_;
}
std::vector<Shape> Gather::output_shapes(const std::vector<array>& inputs) {
Shape out_shape;
if (inputs.size() > 1) {
out_shape = inputs[0].shape();
}
out_shape.insert(out_shape.end(), slice_sizes_.begin(), slice_sizes_.end());
return {std::move(out_shape)};
}
std::pair<std::vector<array>, std::vector<int>> Greater::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
@@ -2184,6 +2223,12 @@ std::pair<std::vector<array>, std::vector<int>> Matmul::vmap(
return {{matmul(a, b, stream())}, {0}};
}
std::vector<Shape> Matmul::output_shapes(const std::vector<array>& inputs) {
auto out_shape = inputs[0].shape();
out_shape.back() = inputs[1].shape(-1);
return {std::move(out_shape)};
}
std::vector<array> Maximum::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
@@ -2608,6 +2653,15 @@ bool QuantizedMatmul::is_equivalent(const Primitive& other) const {
transpose_ == qm_other.transpose_;
}
std::vector<Shape> QuantizedMatmul::output_shapes(
const std::vector<array>& inputs) {
auto& w = inputs[1];
int w_outer_dims = (transpose_) ? w.shape(-2) : w.shape(-1) * 32 / bits_;
auto out_shape = inputs[0].shape();
out_shape.back() = w_outer_dims;
return {std::move(out_shape)};
}
std::pair<std::vector<array>, std::vector<int>> GatherQMM::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
@@ -2937,13 +2991,12 @@ bool Reduce::is_equivalent(const Primitive& other) const {
return reduce_type_ == r_other.reduce_type_ && axes_ == r_other.axes_;
}
std::vector<std::vector<int>> Reduce::output_shapes(
const std::vector<array>& inputs) {
std::vector<int> out_shape = inputs[0].shape();
std::vector<Shape> Reduce::output_shapes(const std::vector<array>& inputs) {
auto out_shape = inputs[0].shape();
for (auto i : axes_) {
out_shape[i] = 1;
}
return {out_shape};
return {std::move(out_shape)};
}
std::vector<array> Round::vjp(
@@ -4209,6 +4262,15 @@ bool Transpose::is_equivalent(const Primitive& other) const {
return axes_ == t_other.axes_;
}
std::vector<Shape> Transpose::output_shapes(const std::vector<array>& inputs) {
auto& in = inputs[0];
Shape shape(in.ndim(), 0);
for (int i = 0; i < axes_.size(); ++i) {
shape[i] = in.shape()[axes_[i]];
}
return {std::move(shape)};
}
std::pair<std::vector<array>, std::vector<int>> NumberOfElements::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {