mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
works
This commit is contained in:
parent
ee59d50293
commit
2b9c24c517
21
mlx/ops.cpp
21
mlx/ops.cpp
@ -419,7 +419,7 @@ array dynamic_reshape(
|
||||
// - At most a.ndim() unique letters
|
||||
// - Only valid characters in string (alphabet, integer, *, /)
|
||||
bool infer_dim = false;
|
||||
std::unordered_set<char> dims;
|
||||
std::unordered_map<char, int> char_to_dim;
|
||||
for (auto& e : expressions) {
|
||||
if (auto pv = std::get_if<int>(&e); pv) {
|
||||
if (*pv == -1) {
|
||||
@ -435,7 +435,7 @@ array dynamic_reshape(
|
||||
for (auto c : s) {
|
||||
if (isalpha(c)) {
|
||||
has_alpha = true;
|
||||
dims.insert(c);
|
||||
char_to_dim.insert({c, char_to_dim.size()});
|
||||
} else if (!isdigit(c) && c != '*' && c != '/') {
|
||||
std::ostringstream msg;
|
||||
msg << "[dynamic_reshape] Invalid character in string expression \""
|
||||
@ -449,20 +449,29 @@ array dynamic_reshape(
|
||||
<< "one alphabetic character but got: \"" << s << "\".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (!isdigit(s[0]) && !isalpha(s[0]) && !isdigit(s.back()) &&
|
||||
!isalpha(s.back())) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dynamic_reshape] String expression must start and end with "
|
||||
<< "integer or letter but got: \"" << s << "\".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
if (dims.size() >= a.ndim()) {
|
||||
if (char_to_dim.size() > a.ndim()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dynamic_reshape] Expressions contain " << dims.size()
|
||||
msg << "[dynamic_reshape] Expressions contain " << char_to_dim.size()
|
||||
<< " abstract dimensions for array with only " << a.ndim()
|
||||
<< " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
auto output_shape = Shape{}; // Reshape::shape_from_expression(a, expression);
|
||||
auto output_shape =
|
||||
Reshape::shape_from_expressions(expressions, char_to_dim, a);
|
||||
return array(
|
||||
std::move(output_shape),
|
||||
a.dtype(),
|
||||
std::make_shared<Reshape>(to_stream(s), std::move(expressions)),
|
||||
std::make_shared<Reshape>(
|
||||
to_stream(s), std::move(expressions), std::move(char_to_dim)),
|
||||
{a});
|
||||
}
|
||||
|
||||
|
@ -2858,24 +2858,21 @@ std::vector<array> Reshape::jvp(
|
||||
|
||||
bool Reshape::is_equivalent(const Primitive& other) const {
|
||||
const Reshape& r_other = static_cast<const Reshape&>(other);
|
||||
if (!expression_.empty()) {
|
||||
return expression_ == r_other.expression_;
|
||||
if (!expressions_.empty()) {
|
||||
return expressions_ == r_other.expressions_;
|
||||
}
|
||||
return shape_ == r_other.shape_;
|
||||
}
|
||||
|
||||
std::vector<Shape> Reshape::output_shapes(const std::vector<array>& inputs) {
|
||||
// Only allowed to dynamically reshape when the shape is {}
|
||||
if (expression_.empty() && !shape_.empty()) {
|
||||
throw std::invalid_argument(
|
||||
"[Reshape::output_shapes] Unable to infer output shape.");
|
||||
}
|
||||
|
||||
auto& in = inputs[0];
|
||||
Shape output_shape(expression_.size());
|
||||
Shape Reshape::shape_from_expressions(
|
||||
const std::vector<std::variant<int, std::string>>& expressions,
|
||||
const std::unordered_map<char, int>& char_to_dim,
|
||||
const array& in) {
|
||||
Shape output_shape(expressions.size());
|
||||
int dim_to_infer = -1;
|
||||
for (int i = 0, j = 0; i < expression_.size(); ++i) {
|
||||
auto& e = expression_[i];
|
||||
uint64_t size = 1;
|
||||
for (int i = 0; i < expressions.size(); ++i) {
|
||||
auto& e = expressions[i];
|
||||
if (auto pv = std::get_if<int>(&e); pv) {
|
||||
if (*pv == -1) {
|
||||
dim_to_infer = i;
|
||||
@ -2885,20 +2882,66 @@ std::vector<Shape> Reshape::output_shapes(const std::vector<array>& inputs) {
|
||||
}
|
||||
} else {
|
||||
auto& s = std::get<std::string>(e);
|
||||
output_shape[i] = in.shape()[j++];
|
||||
if (s.size() == 1) {
|
||||
output_shape[i] = in.shape()[char_to_dim.at(s[0])];
|
||||
} else {
|
||||
int d;
|
||||
size_t loc = 0;
|
||||
char op = 0;
|
||||
while (loc < s.size()) {
|
||||
int res;
|
||||
if (std::isdigit(s[loc])) {
|
||||
char* p;
|
||||
res = std::strtol(s.c_str() + loc, &p, 10);
|
||||
loc = (p - s.c_str());
|
||||
} else if (std::isalpha(s[loc])) {
|
||||
res = in.shape()[char_to_dim.at(s[loc++])];
|
||||
} else if (s[loc] == '*' || s[loc] == '/') {
|
||||
op = s[loc++];
|
||||
continue;
|
||||
}
|
||||
|
||||
if (op == '*') {
|
||||
d *= res;
|
||||
} else if (op == '/') {
|
||||
d /= res;
|
||||
} else {
|
||||
d = res;
|
||||
}
|
||||
}
|
||||
output_shape[i] = d;
|
||||
}
|
||||
}
|
||||
size *= output_shape[i];
|
||||
}
|
||||
|
||||
if (dim_to_infer >= 0) {
|
||||
uint64_t output_size = 1;
|
||||
for (int i = 0; i < output_shape.size(); ++i) {
|
||||
if (i != dim_to_infer) {
|
||||
output_size *= output_shape[i];
|
||||
}
|
||||
if (size == 0) {
|
||||
throw std::invalid_argument(
|
||||
"[dynamic_reshape] Cannot infer the shape of an empty array.");
|
||||
}
|
||||
output_shape[dim_to_infer] = in.size() / output_size;
|
||||
auto d = in.size() / size;
|
||||
output_shape[dim_to_infer] = d;
|
||||
size *= d;
|
||||
}
|
||||
return {std::move(output_shape)};
|
||||
|
||||
if (in.size() != size) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dynamic_reshape] Cannot reshape array of size " << in.size()
|
||||
<< " into shape " << output_shape << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
return output_shape;
|
||||
}
|
||||
|
||||
std::vector<Shape> Reshape::output_shapes(const std::vector<array>& inputs) {
|
||||
// Only allowed to dynamically reshape when the shape is {}
|
||||
if (expressions_.empty() && !shape_.empty()) {
|
||||
throw std::invalid_argument(
|
||||
"[Reshape::output_shapes] Unable to infer output shape.");
|
||||
}
|
||||
return {shape_from_expressions(expressions_, char_to_dim_, inputs[0])};
|
||||
}
|
||||
|
||||
std::vector<array> Reduce::vjp(
|
||||
|
@ -1611,8 +1611,11 @@ class Reshape : public UnaryPrimitive {
|
||||
|
||||
explicit Reshape(
|
||||
Stream stream,
|
||||
std::vector<std::variant<int, std::string>> expression)
|
||||
: UnaryPrimitive(stream), expression_(std::move(expression)) {}
|
||||
std::vector<std::variant<int, std::string>> expressions,
|
||||
std::unordered_map<char, int> char_to_dim)
|
||||
: UnaryPrimitive(stream),
|
||||
expressions_(std::move(expressions)),
|
||||
char_to_dim_(std::move(char_to_dim)) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1623,9 +1626,15 @@ class Reshape : public UnaryPrimitive {
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
|
||||
static Shape shape_from_expressions(
|
||||
const std::vector<std::variant<int, std::string>>& expressions,
|
||||
const std::unordered_map<char, int>& char_to_dim,
|
||||
const array& in);
|
||||
|
||||
private:
|
||||
Shape shape_;
|
||||
std::vector<std::variant<int, std::string>> expression_;
|
||||
std::vector<std::variant<int, std::string>> expressions_;
|
||||
std::unordered_map<char, int> char_to_dim_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
|
||||
|
@ -4,6 +4,7 @@
|
||||
#define _USE_MATH_DEFINES
|
||||
|
||||
#include <cmath>
|
||||
#include <iostream> // TODO
|
||||
#include <numeric>
|
||||
|
||||
#include "doctest/doctest.h"
|
||||
@ -3777,6 +3778,10 @@ TEST_CASE("test dynamic reshape") {
|
||||
// Bad character
|
||||
CHECK_THROWS(dynamic_reshape(x, {"&", 1, 1}));
|
||||
|
||||
// Malformed
|
||||
CHECK_THROWS(dynamic_reshape(x, {"+a", 1, 1}));
|
||||
CHECK_THROWS(dynamic_reshape(x, {"a+", 1, 1}));
|
||||
|
||||
// No dim in string
|
||||
CHECK_THROWS(dynamic_reshape(x, {"1", 1, 1}));
|
||||
|
||||
@ -3785,7 +3790,41 @@ TEST_CASE("test dynamic reshape") {
|
||||
|
||||
// Too many dims
|
||||
CHECK_THROWS(dynamic_reshape(x, {"a", "b", "c", "d"}));
|
||||
CHECK_THROWS(dynamic_reshape(x, {"abcd", 1, 1}));
|
||||
|
||||
// Too many inferred dims
|
||||
CHECK_THROWS(dynamic_reshape(x, {"a", -1, -1}));
|
||||
|
||||
// Bad sizes
|
||||
x = zeros({2, 2, 2});
|
||||
CHECK_THROWS_AS(dynamic_reshape(x, {7}), std::invalid_argument);
|
||||
CHECK_THROWS_AS(dynamic_reshape(x, {-1, 7}), std::invalid_argument);
|
||||
|
||||
// Works with empty array
|
||||
x = array({});
|
||||
auto y = dynamic_reshape(x, {0, 0, 0});
|
||||
CHECK_EQ(y.shape(), Shape{0, 0, 0});
|
||||
CHECK_THROWS_AS(dynamic_reshape(x, {}), std::invalid_argument);
|
||||
CHECK_THROWS_AS(dynamic_reshape(x, {1}), std::invalid_argument);
|
||||
y = dynamic_reshape(x, {1, 5, 0});
|
||||
CHECK_EQ(y.shape(), Shape{1, 5, 0});
|
||||
|
||||
x = array({1, 2, 3});
|
||||
y = dynamic_reshape(x, {"a", 1, 1});
|
||||
CHECK_EQ(y.shape(), Shape{3, 1, 1});
|
||||
|
||||
x = zeros({2, 2});
|
||||
y = dynamic_reshape(x, {"a*b"});
|
||||
CHECK_EQ(y.shape(), Shape{4});
|
||||
|
||||
y = dynamic_reshape(x, {"2*a"});
|
||||
CHECK_EQ(y.shape(), Shape{4});
|
||||
|
||||
x = zeros({2, 20});
|
||||
y = dynamic_reshape(x, {"a*20"});
|
||||
CHECK_EQ(y.shape(), Shape{40});
|
||||
|
||||
x = zeros({2, 20});
|
||||
y = dynamic_reshape(x, {"a", "b/10", 10});
|
||||
CHECK_EQ(y.shape(), Shape{2, 2, 10});
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user