This commit is contained in:
Awni Hannun 2024-12-06 16:50:08 -08:00
parent ee59d50293
commit 2b9c24c517
4 changed files with 130 additions and 30 deletions

View File

@ -419,7 +419,7 @@ array dynamic_reshape(
// - At most a.ndim() unique letters // - At most a.ndim() unique letters
// - Only valid characters in string (alphabet, integer, *, /) // - Only valid characters in string (alphabet, integer, *, /)
bool infer_dim = false; bool infer_dim = false;
std::unordered_set<char> dims; std::unordered_map<char, int> char_to_dim;
for (auto& e : expressions) { for (auto& e : expressions) {
if (auto pv = std::get_if<int>(&e); pv) { if (auto pv = std::get_if<int>(&e); pv) {
if (*pv == -1) { if (*pv == -1) {
@ -435,7 +435,7 @@ array dynamic_reshape(
for (auto c : s) { for (auto c : s) {
if (isalpha(c)) { if (isalpha(c)) {
has_alpha = true; has_alpha = true;
dims.insert(c); char_to_dim.insert({c, char_to_dim.size()});
} else if (!isdigit(c) && c != '*' && c != '/') { } else if (!isdigit(c) && c != '*' && c != '/') {
std::ostringstream msg; std::ostringstream msg;
msg << "[dynamic_reshape] Invalid character in string expression \"" msg << "[dynamic_reshape] Invalid character in string expression \""
@ -449,20 +449,29 @@ array dynamic_reshape(
<< "one alphabetic character but got: \"" << s << "\"."; << "one alphabetic character but got: \"" << s << "\".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (!isdigit(s[0]) && !isalpha(s[0]) && !isdigit(s.back()) &&
!isalpha(s.back())) {
std::ostringstream msg;
msg << "[dynamic_reshape] String expression must start and end with "
<< "integer or letter but got: \"" << s << "\".";
throw std::invalid_argument(msg.str());
}
} }
} }
if (dims.size() >= a.ndim()) { if (char_to_dim.size() > a.ndim()) {
std::ostringstream msg; std::ostringstream msg;
msg << "[dynamic_reshape] Expressions contain " << dims.size() msg << "[dynamic_reshape] Expressions contain " << char_to_dim.size()
<< " abstract dimensions for array with only " << a.ndim() << " abstract dimensions for array with only " << a.ndim()
<< " dimensions."; << " dimensions.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
auto output_shape = Shape{}; // Reshape::shape_from_expression(a, expression); auto output_shape =
Reshape::shape_from_expressions(expressions, char_to_dim, a);
return array( return array(
std::move(output_shape), std::move(output_shape),
a.dtype(), a.dtype(),
std::make_shared<Reshape>(to_stream(s), std::move(expressions)), std::make_shared<Reshape>(
to_stream(s), std::move(expressions), std::move(char_to_dim)),
{a}); {a});
} }

View File

@ -2858,24 +2858,21 @@ std::vector<array> Reshape::jvp(
bool Reshape::is_equivalent(const Primitive& other) const { bool Reshape::is_equivalent(const Primitive& other) const {
const Reshape& r_other = static_cast<const Reshape&>(other); const Reshape& r_other = static_cast<const Reshape&>(other);
if (!expression_.empty()) { if (!expressions_.empty()) {
return expression_ == r_other.expression_; return expressions_ == r_other.expressions_;
} }
return shape_ == r_other.shape_; return shape_ == r_other.shape_;
} }
std::vector<Shape> Reshape::output_shapes(const std::vector<array>& inputs) { Shape Reshape::shape_from_expressions(
// Only allowed to dynamically reshape when the shape is {} const std::vector<std::variant<int, std::string>>& expressions,
if (expression_.empty() && !shape_.empty()) { const std::unordered_map<char, int>& char_to_dim,
throw std::invalid_argument( const array& in) {
"[Reshape::output_shapes] Unable to infer output shape."); Shape output_shape(expressions.size());
}
auto& in = inputs[0];
Shape output_shape(expression_.size());
int dim_to_infer = -1; int dim_to_infer = -1;
for (int i = 0, j = 0; i < expression_.size(); ++i) { uint64_t size = 1;
auto& e = expression_[i]; for (int i = 0; i < expressions.size(); ++i) {
auto& e = expressions[i];
if (auto pv = std::get_if<int>(&e); pv) { if (auto pv = std::get_if<int>(&e); pv) {
if (*pv == -1) { if (*pv == -1) {
dim_to_infer = i; dim_to_infer = i;
@ -2885,20 +2882,66 @@ std::vector<Shape> Reshape::output_shapes(const std::vector<array>& inputs) {
} }
} else { } else {
auto& s = std::get<std::string>(e); auto& s = std::get<std::string>(e);
output_shape[i] = in.shape()[j++]; if (s.size() == 1) {
output_shape[i] = in.shape()[char_to_dim.at(s[0])];
} else {
int d;
size_t loc = 0;
char op = 0;
while (loc < s.size()) {
int res;
if (std::isdigit(s[loc])) {
char* p;
res = std::strtol(s.c_str() + loc, &p, 10);
loc = (p - s.c_str());
} else if (std::isalpha(s[loc])) {
res = in.shape()[char_to_dim.at(s[loc++])];
} else if (s[loc] == '*' || s[loc] == '/') {
op = s[loc++];
continue;
}
if (op == '*') {
d *= res;
} else if (op == '/') {
d /= res;
} else {
d = res;
}
}
output_shape[i] = d;
}
} }
size *= output_shape[i];
} }
if (dim_to_infer >= 0) { if (dim_to_infer >= 0) {
uint64_t output_size = 1; if (size == 0) {
for (int i = 0; i < output_shape.size(); ++i) { throw std::invalid_argument(
if (i != dim_to_infer) { "[dynamic_reshape] Cannot infer the shape of an empty array.");
output_size *= output_shape[i];
}
} }
output_shape[dim_to_infer] = in.size() / output_size; auto d = in.size() / size;
output_shape[dim_to_infer] = d;
size *= d;
} }
return {std::move(output_shape)};
if (in.size() != size) {
std::ostringstream msg;
msg << "[dynamic_reshape] Cannot reshape array of size " << in.size()
<< " into shape " << output_shape << ".";
throw std::invalid_argument(msg.str());
}
return output_shape;
}
std::vector<Shape> Reshape::output_shapes(const std::vector<array>& inputs) {
// Only allowed to dynamically reshape when the shape is {}
if (expressions_.empty() && !shape_.empty()) {
throw std::invalid_argument(
"[Reshape::output_shapes] Unable to infer output shape.");
}
return {shape_from_expressions(expressions_, char_to_dim_, inputs[0])};
} }
std::vector<array> Reduce::vjp( std::vector<array> Reduce::vjp(

View File

@ -1611,8 +1611,11 @@ class Reshape : public UnaryPrimitive {
explicit Reshape( explicit Reshape(
Stream stream, Stream stream,
std::vector<std::variant<int, std::string>> expression) std::vector<std::variant<int, std::string>> expressions,
: UnaryPrimitive(stream), expression_(std::move(expression)) {} std::unordered_map<char, int> char_to_dim)
: UnaryPrimitive(stream),
expressions_(std::move(expressions)),
char_to_dim_(std::move(char_to_dim)) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1623,9 +1626,15 @@ class Reshape : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override; std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
static Shape shape_from_expressions(
const std::vector<std::variant<int, std::string>>& expressions,
const std::unordered_map<char, int>& char_to_dim,
const array& in);
private: private:
Shape shape_; Shape shape_;
std::vector<std::variant<int, std::string>> expression_; std::vector<std::variant<int, std::string>> expressions_;
std::unordered_map<char, int> char_to_dim_;
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);

View File

@ -4,6 +4,7 @@
#define _USE_MATH_DEFINES #define _USE_MATH_DEFINES
#include <cmath> #include <cmath>
#include <iostream> // TODO
#include <numeric> #include <numeric>
#include "doctest/doctest.h" #include "doctest/doctest.h"
@ -3777,6 +3778,10 @@ TEST_CASE("test dynamic reshape") {
// Bad character // Bad character
CHECK_THROWS(dynamic_reshape(x, {"&", 1, 1})); CHECK_THROWS(dynamic_reshape(x, {"&", 1, 1}));
// Malformed
CHECK_THROWS(dynamic_reshape(x, {"+a", 1, 1}));
CHECK_THROWS(dynamic_reshape(x, {"a+", 1, 1}));
// No dim in string // No dim in string
CHECK_THROWS(dynamic_reshape(x, {"1", 1, 1})); CHECK_THROWS(dynamic_reshape(x, {"1", 1, 1}));
@ -3785,7 +3790,41 @@ TEST_CASE("test dynamic reshape") {
// Too many dims // Too many dims
CHECK_THROWS(dynamic_reshape(x, {"a", "b", "c", "d"})); CHECK_THROWS(dynamic_reshape(x, {"a", "b", "c", "d"}));
CHECK_THROWS(dynamic_reshape(x, {"abcd", 1, 1}));
// Too many inferred dims // Too many inferred dims
CHECK_THROWS(dynamic_reshape(x, {"a", -1, -1})); CHECK_THROWS(dynamic_reshape(x, {"a", -1, -1}));
// Bad sizes
x = zeros({2, 2, 2});
CHECK_THROWS_AS(dynamic_reshape(x, {7}), std::invalid_argument);
CHECK_THROWS_AS(dynamic_reshape(x, {-1, 7}), std::invalid_argument);
// Works with empty array
x = array({});
auto y = dynamic_reshape(x, {0, 0, 0});
CHECK_EQ(y.shape(), Shape{0, 0, 0});
CHECK_THROWS_AS(dynamic_reshape(x, {}), std::invalid_argument);
CHECK_THROWS_AS(dynamic_reshape(x, {1}), std::invalid_argument);
y = dynamic_reshape(x, {1, 5, 0});
CHECK_EQ(y.shape(), Shape{1, 5, 0});
x = array({1, 2, 3});
y = dynamic_reshape(x, {"a", 1, 1});
CHECK_EQ(y.shape(), Shape{3, 1, 1});
x = zeros({2, 2});
y = dynamic_reshape(x, {"a*b"});
CHECK_EQ(y.shape(), Shape{4});
y = dynamic_reshape(x, {"2*a"});
CHECK_EQ(y.shape(), Shape{4});
x = zeros({2, 20});
y = dynamic_reshape(x, {"a*20"});
CHECK_EQ(y.shape(), Shape{40});
x = zeros({2, 20});
y = dynamic_reshape(x, {"a", "b/10", 10});
CHECK_EQ(y.shape(), Shape{2, 2, 10});
} }