mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
works
This commit is contained in:
parent
ee59d50293
commit
2b9c24c517
21
mlx/ops.cpp
21
mlx/ops.cpp
@ -419,7 +419,7 @@ array dynamic_reshape(
|
|||||||
// - At most a.ndim() unique letters
|
// - At most a.ndim() unique letters
|
||||||
// - Only valid characters in string (alphabet, integer, *, /)
|
// - Only valid characters in string (alphabet, integer, *, /)
|
||||||
bool infer_dim = false;
|
bool infer_dim = false;
|
||||||
std::unordered_set<char> dims;
|
std::unordered_map<char, int> char_to_dim;
|
||||||
for (auto& e : expressions) {
|
for (auto& e : expressions) {
|
||||||
if (auto pv = std::get_if<int>(&e); pv) {
|
if (auto pv = std::get_if<int>(&e); pv) {
|
||||||
if (*pv == -1) {
|
if (*pv == -1) {
|
||||||
@ -435,7 +435,7 @@ array dynamic_reshape(
|
|||||||
for (auto c : s) {
|
for (auto c : s) {
|
||||||
if (isalpha(c)) {
|
if (isalpha(c)) {
|
||||||
has_alpha = true;
|
has_alpha = true;
|
||||||
dims.insert(c);
|
char_to_dim.insert({c, char_to_dim.size()});
|
||||||
} else if (!isdigit(c) && c != '*' && c != '/') {
|
} else if (!isdigit(c) && c != '*' && c != '/') {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[dynamic_reshape] Invalid character in string expression \""
|
msg << "[dynamic_reshape] Invalid character in string expression \""
|
||||||
@ -449,20 +449,29 @@ array dynamic_reshape(
|
|||||||
<< "one alphabetic character but got: \"" << s << "\".";
|
<< "one alphabetic character but got: \"" << s << "\".";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
if (!isdigit(s[0]) && !isalpha(s[0]) && !isdigit(s.back()) &&
|
||||||
|
!isalpha(s.back())) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[dynamic_reshape] String expression must start and end with "
|
||||||
|
<< "integer or letter but got: \"" << s << "\".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (dims.size() >= a.ndim()) {
|
if (char_to_dim.size() > a.ndim()) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[dynamic_reshape] Expressions contain " << dims.size()
|
msg << "[dynamic_reshape] Expressions contain " << char_to_dim.size()
|
||||||
<< " abstract dimensions for array with only " << a.ndim()
|
<< " abstract dimensions for array with only " << a.ndim()
|
||||||
<< " dimensions.";
|
<< " dimensions.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
auto output_shape = Shape{}; // Reshape::shape_from_expression(a, expression);
|
auto output_shape =
|
||||||
|
Reshape::shape_from_expressions(expressions, char_to_dim, a);
|
||||||
return array(
|
return array(
|
||||||
std::move(output_shape),
|
std::move(output_shape),
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
std::make_shared<Reshape>(to_stream(s), std::move(expressions)),
|
std::make_shared<Reshape>(
|
||||||
|
to_stream(s), std::move(expressions), std::move(char_to_dim)),
|
||||||
{a});
|
{a});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2858,24 +2858,21 @@ std::vector<array> Reshape::jvp(
|
|||||||
|
|
||||||
bool Reshape::is_equivalent(const Primitive& other) const {
|
bool Reshape::is_equivalent(const Primitive& other) const {
|
||||||
const Reshape& r_other = static_cast<const Reshape&>(other);
|
const Reshape& r_other = static_cast<const Reshape&>(other);
|
||||||
if (!expression_.empty()) {
|
if (!expressions_.empty()) {
|
||||||
return expression_ == r_other.expression_;
|
return expressions_ == r_other.expressions_;
|
||||||
}
|
}
|
||||||
return shape_ == r_other.shape_;
|
return shape_ == r_other.shape_;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Shape> Reshape::output_shapes(const std::vector<array>& inputs) {
|
Shape Reshape::shape_from_expressions(
|
||||||
// Only allowed to dynamically reshape when the shape is {}
|
const std::vector<std::variant<int, std::string>>& expressions,
|
||||||
if (expression_.empty() && !shape_.empty()) {
|
const std::unordered_map<char, int>& char_to_dim,
|
||||||
throw std::invalid_argument(
|
const array& in) {
|
||||||
"[Reshape::output_shapes] Unable to infer output shape.");
|
Shape output_shape(expressions.size());
|
||||||
}
|
|
||||||
|
|
||||||
auto& in = inputs[0];
|
|
||||||
Shape output_shape(expression_.size());
|
|
||||||
int dim_to_infer = -1;
|
int dim_to_infer = -1;
|
||||||
for (int i = 0, j = 0; i < expression_.size(); ++i) {
|
uint64_t size = 1;
|
||||||
auto& e = expression_[i];
|
for (int i = 0; i < expressions.size(); ++i) {
|
||||||
|
auto& e = expressions[i];
|
||||||
if (auto pv = std::get_if<int>(&e); pv) {
|
if (auto pv = std::get_if<int>(&e); pv) {
|
||||||
if (*pv == -1) {
|
if (*pv == -1) {
|
||||||
dim_to_infer = i;
|
dim_to_infer = i;
|
||||||
@ -2885,20 +2882,66 @@ std::vector<Shape> Reshape::output_shapes(const std::vector<array>& inputs) {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
auto& s = std::get<std::string>(e);
|
auto& s = std::get<std::string>(e);
|
||||||
output_shape[i] = in.shape()[j++];
|
if (s.size() == 1) {
|
||||||
|
output_shape[i] = in.shape()[char_to_dim.at(s[0])];
|
||||||
|
} else {
|
||||||
|
int d;
|
||||||
|
size_t loc = 0;
|
||||||
|
char op = 0;
|
||||||
|
while (loc < s.size()) {
|
||||||
|
int res;
|
||||||
|
if (std::isdigit(s[loc])) {
|
||||||
|
char* p;
|
||||||
|
res = std::strtol(s.c_str() + loc, &p, 10);
|
||||||
|
loc = (p - s.c_str());
|
||||||
|
} else if (std::isalpha(s[loc])) {
|
||||||
|
res = in.shape()[char_to_dim.at(s[loc++])];
|
||||||
|
} else if (s[loc] == '*' || s[loc] == '/') {
|
||||||
|
op = s[loc++];
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op == '*') {
|
||||||
|
d *= res;
|
||||||
|
} else if (op == '/') {
|
||||||
|
d /= res;
|
||||||
|
} else {
|
||||||
|
d = res;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output_shape[i] = d;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
size *= output_shape[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (dim_to_infer >= 0) {
|
if (dim_to_infer >= 0) {
|
||||||
uint64_t output_size = 1;
|
if (size == 0) {
|
||||||
for (int i = 0; i < output_shape.size(); ++i) {
|
throw std::invalid_argument(
|
||||||
if (i != dim_to_infer) {
|
"[dynamic_reshape] Cannot infer the shape of an empty array.");
|
||||||
output_size *= output_shape[i];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
output_shape[dim_to_infer] = in.size() / output_size;
|
auto d = in.size() / size;
|
||||||
|
output_shape[dim_to_infer] = d;
|
||||||
|
size *= d;
|
||||||
}
|
}
|
||||||
return {std::move(output_shape)};
|
|
||||||
|
if (in.size() != size) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[dynamic_reshape] Cannot reshape array of size " << in.size()
|
||||||
|
<< " into shape " << output_shape << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
return output_shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Shape> Reshape::output_shapes(const std::vector<array>& inputs) {
|
||||||
|
// Only allowed to dynamically reshape when the shape is {}
|
||||||
|
if (expressions_.empty() && !shape_.empty()) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[Reshape::output_shapes] Unable to infer output shape.");
|
||||||
|
}
|
||||||
|
return {shape_from_expressions(expressions_, char_to_dim_, inputs[0])};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> Reduce::vjp(
|
std::vector<array> Reduce::vjp(
|
||||||
|
@ -1611,8 +1611,11 @@ class Reshape : public UnaryPrimitive {
|
|||||||
|
|
||||||
explicit Reshape(
|
explicit Reshape(
|
||||||
Stream stream,
|
Stream stream,
|
||||||
std::vector<std::variant<int, std::string>> expression)
|
std::vector<std::variant<int, std::string>> expressions,
|
||||||
: UnaryPrimitive(stream), expression_(std::move(expression)) {}
|
std::unordered_map<char, int> char_to_dim)
|
||||||
|
: UnaryPrimitive(stream),
|
||||||
|
expressions_(std::move(expressions)),
|
||||||
|
char_to_dim_(std::move(char_to_dim)) {}
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1623,9 +1626,15 @@ class Reshape : public UnaryPrimitive {
|
|||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
|
|
||||||
|
static Shape shape_from_expressions(
|
||||||
|
const std::vector<std::variant<int, std::string>>& expressions,
|
||||||
|
const std::unordered_map<char, int>& char_to_dim,
|
||||||
|
const array& in);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Shape shape_;
|
Shape shape_;
|
||||||
std::vector<std::variant<int, std::string>> expression_;
|
std::vector<std::variant<int, std::string>> expressions_;
|
||||||
|
std::unordered_map<char, int> char_to_dim_;
|
||||||
|
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
#define _USE_MATH_DEFINES
|
#define _USE_MATH_DEFINES
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <iostream> // TODO
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
#include "doctest/doctest.h"
|
#include "doctest/doctest.h"
|
||||||
@ -3777,6 +3778,10 @@ TEST_CASE("test dynamic reshape") {
|
|||||||
// Bad character
|
// Bad character
|
||||||
CHECK_THROWS(dynamic_reshape(x, {"&", 1, 1}));
|
CHECK_THROWS(dynamic_reshape(x, {"&", 1, 1}));
|
||||||
|
|
||||||
|
// Malformed
|
||||||
|
CHECK_THROWS(dynamic_reshape(x, {"+a", 1, 1}));
|
||||||
|
CHECK_THROWS(dynamic_reshape(x, {"a+", 1, 1}));
|
||||||
|
|
||||||
// No dim in string
|
// No dim in string
|
||||||
CHECK_THROWS(dynamic_reshape(x, {"1", 1, 1}));
|
CHECK_THROWS(dynamic_reshape(x, {"1", 1, 1}));
|
||||||
|
|
||||||
@ -3785,7 +3790,41 @@ TEST_CASE("test dynamic reshape") {
|
|||||||
|
|
||||||
// Too many dims
|
// Too many dims
|
||||||
CHECK_THROWS(dynamic_reshape(x, {"a", "b", "c", "d"}));
|
CHECK_THROWS(dynamic_reshape(x, {"a", "b", "c", "d"}));
|
||||||
|
CHECK_THROWS(dynamic_reshape(x, {"abcd", 1, 1}));
|
||||||
|
|
||||||
// Too many inferred dims
|
// Too many inferred dims
|
||||||
CHECK_THROWS(dynamic_reshape(x, {"a", -1, -1}));
|
CHECK_THROWS(dynamic_reshape(x, {"a", -1, -1}));
|
||||||
|
|
||||||
|
// Bad sizes
|
||||||
|
x = zeros({2, 2, 2});
|
||||||
|
CHECK_THROWS_AS(dynamic_reshape(x, {7}), std::invalid_argument);
|
||||||
|
CHECK_THROWS_AS(dynamic_reshape(x, {-1, 7}), std::invalid_argument);
|
||||||
|
|
||||||
|
// Works with empty array
|
||||||
|
x = array({});
|
||||||
|
auto y = dynamic_reshape(x, {0, 0, 0});
|
||||||
|
CHECK_EQ(y.shape(), Shape{0, 0, 0});
|
||||||
|
CHECK_THROWS_AS(dynamic_reshape(x, {}), std::invalid_argument);
|
||||||
|
CHECK_THROWS_AS(dynamic_reshape(x, {1}), std::invalid_argument);
|
||||||
|
y = dynamic_reshape(x, {1, 5, 0});
|
||||||
|
CHECK_EQ(y.shape(), Shape{1, 5, 0});
|
||||||
|
|
||||||
|
x = array({1, 2, 3});
|
||||||
|
y = dynamic_reshape(x, {"a", 1, 1});
|
||||||
|
CHECK_EQ(y.shape(), Shape{3, 1, 1});
|
||||||
|
|
||||||
|
x = zeros({2, 2});
|
||||||
|
y = dynamic_reshape(x, {"a*b"});
|
||||||
|
CHECK_EQ(y.shape(), Shape{4});
|
||||||
|
|
||||||
|
y = dynamic_reshape(x, {"2*a"});
|
||||||
|
CHECK_EQ(y.shape(), Shape{4});
|
||||||
|
|
||||||
|
x = zeros({2, 20});
|
||||||
|
y = dynamic_reshape(x, {"a*20"});
|
||||||
|
CHECK_EQ(y.shape(), Shape{40});
|
||||||
|
|
||||||
|
x = zeros({2, 20});
|
||||||
|
y = dynamic_reshape(x, {"a", "b/10", 10});
|
||||||
|
CHECK_EQ(y.shape(), Shape{2, 2, 10});
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user