shapeless compile in docs and partially shapeless reshape (#1742)

This commit is contained in:
Awni Hannun
2025-01-02 16:24:42 -08:00
committed by GitHub
parent a64a8dfe45
commit ae69cb15e9
6 changed files with 137 additions and 36 deletions

View File

@@ -363,41 +363,12 @@ array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) {
if (a.shape() == shape) {
return a;
}
size_t size = 1;
int infer_idx = -1;
for (int i = 0; i < shape.size(); ++i) {
if (shape[i] == -1) {
if (infer_idx >= 0) {
throw std::invalid_argument(
"[reshape] Reshape can only infer one dimension.");
}
infer_idx = i;
} else {
size *= shape[i];
}
}
// Infer the shape
if (size > 0) {
if (infer_idx >= 0) {
shape[infer_idx] = a.size() / size;
size *= shape[infer_idx];
}
} else if (infer_idx >= 0) {
throw std::invalid_argument(
"[reshape] Cannot infer the shape of an empty array");
}
// Check that the reshaping is valid
if (a.size() != size) {
std::ostringstream msg;
msg << "[reshape] Cannot reshape array of size " << a.size()
<< " into shape " << shape << ".";
throw std::invalid_argument(msg.str());
}
auto p = std::make_shared<Reshape>(to_stream(s), shape);
return array(std::move(shape), a.dtype(), std::move(p), {a});
auto out_shape = Reshape::output_shape(a, shape);
return array(
std::move(out_shape),
a.dtype(),
std::make_shared<Reshape>(to_stream(s), std::move(shape)),
{a});
}
array unflatten(

View File

@@ -3021,6 +3021,44 @@ bool Reshape::is_equivalent(const Primitive& other) const {
return shape_ == r_other.shape_;
}
Shape Reshape::output_shape(const array& input, Shape shape) {
size_t size = 1;
int infer_idx = -1;
for (int i = 0; i < shape.size(); ++i) {
if (shape[i] == -1) {
if (infer_idx >= 0) {
throw std::invalid_argument(
"[reshape] Reshape can only infer one dimension.");
}
infer_idx = i;
} else {
size *= shape[i];
}
}
// Infer the shape
if (size > 0 && infer_idx >= 0) {
shape[infer_idx] = input.size() / size;
size *= shape[infer_idx];
} else if (infer_idx >= 0) {
throw std::invalid_argument(
"[reshape] Cannot infer the shape of an empty array");
}
// Check that the reshaping is valid
if (input.size() != size) {
std::ostringstream msg;
msg << "[reshape] Cannot reshape array of size " << input.size()
<< " into shape " << shape << ".";
throw std::invalid_argument(msg.str());
}
return shape;
}
std::vector<Shape> Reshape::output_shapes(const std::vector<array>& inputs) {
return {output_shape(inputs[0], shape_)};
}
std::vector<array> Reduce::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,

View File

@@ -1746,6 +1746,8 @@ class Reshape : public UnaryPrimitive {
std::vector<int> state() const {
return shape_;
};
static Shape output_shape(const array& input, Shape shape);
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
private:
Shape shape_;