mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-09 10:46:39 +08:00
try dynamic reshape
This commit is contained in:
parent
40c62c1321
commit
ee59d50293
63
mlx/ops.cpp
63
mlx/ops.cpp
@ -403,6 +403,69 @@ array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) {
|
|||||||
return array(std::move(shape), a.dtype(), std::move(p), {a});
|
return array(std::move(shape), a.dtype(), std::move(p), {a});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Variant of string and int for the expressions
|
||||||
|
array dynamic_reshape(
|
||||||
|
const array& a,
|
||||||
|
std::vector<std::variant<int, std::string>> expressions,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
// Reshape to scalar is not dynamic
|
||||||
|
if (expressions.empty()) {
|
||||||
|
return reshape(a, {}, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate expressions:
|
||||||
|
// - At most one item in expressions is -1
|
||||||
|
// - Any string expression should have a letter
|
||||||
|
// - At most a.ndim() unique letters
|
||||||
|
// - Only valid characters in string (alphabet, integer, *, /)
|
||||||
|
bool infer_dim = false;
|
||||||
|
std::unordered_set<char> dims;
|
||||||
|
for (auto& e : expressions) {
|
||||||
|
if (auto pv = std::get_if<int>(&e); pv) {
|
||||||
|
if (*pv == -1) {
|
||||||
|
if (infer_dim) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[dynamic_reshape] Cannot infer more than one dimension.");
|
||||||
|
}
|
||||||
|
infer_dim = true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto& s = std::get<std::string>(e);
|
||||||
|
bool has_alpha = false;
|
||||||
|
for (auto c : s) {
|
||||||
|
if (isalpha(c)) {
|
||||||
|
has_alpha = true;
|
||||||
|
dims.insert(c);
|
||||||
|
} else if (!isdigit(c) && c != '*' && c != '/') {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[dynamic_reshape] Invalid character in string expression \""
|
||||||
|
<< s << "\".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!has_alpha) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[dynamic_reshape] String expression must contain at least "
|
||||||
|
<< "one alphabetic character but got: \"" << s << "\".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (dims.size() >= a.ndim()) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[dynamic_reshape] Expressions contain " << dims.size()
|
||||||
|
<< " abstract dimensions for array with only " << a.ndim()
|
||||||
|
<< " dimensions.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
auto output_shape = Shape{}; // Reshape::shape_from_expression(a, expression);
|
||||||
|
return array(
|
||||||
|
std::move(output_shape),
|
||||||
|
a.dtype(),
|
||||||
|
std::make_shared<Reshape>(to_stream(s), std::move(expressions)),
|
||||||
|
{a});
|
||||||
|
}
|
||||||
|
|
||||||
array flatten(
|
array flatten(
|
||||||
const array& a,
|
const array& a,
|
||||||
int start_axis,
|
int start_axis,
|
||||||
|
@ -117,6 +117,12 @@ array triu(array x, int k = 0, StreamOrDevice s = {});
|
|||||||
/** Reshape an array to the given shape. */
|
/** Reshape an array to the given shape. */
|
||||||
array reshape(const array& a, Shape shape, StreamOrDevice s = {});
|
array reshape(const array& a, Shape shape, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Dynamically reshape an array based on the given expressions. */
|
||||||
|
array dynamic_reshape(
|
||||||
|
const array& a,
|
||||||
|
std::vector<std::variant<int, std::string>> expressions,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Flatten the dimensions in the range `[start_axis, end_axis]` . */
|
/** Flatten the dimensions in the range `[start_axis, end_axis]` . */
|
||||||
array flatten(
|
array flatten(
|
||||||
const array& a,
|
const array& a,
|
||||||
|
@ -2858,9 +2858,49 @@ std::vector<array> Reshape::jvp(
|
|||||||
|
|
||||||
bool Reshape::is_equivalent(const Primitive& other) const {
|
bool Reshape::is_equivalent(const Primitive& other) const {
|
||||||
const Reshape& r_other = static_cast<const Reshape&>(other);
|
const Reshape& r_other = static_cast<const Reshape&>(other);
|
||||||
|
if (!expression_.empty()) {
|
||||||
|
return expression_ == r_other.expression_;
|
||||||
|
}
|
||||||
return shape_ == r_other.shape_;
|
return shape_ == r_other.shape_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<Shape> Reshape::output_shapes(const std::vector<array>& inputs) {
|
||||||
|
// Only allowed to dynamically reshape when the shape is {}
|
||||||
|
if (expression_.empty() && !shape_.empty()) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[Reshape::output_shapes] Unable to infer output shape.");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& in = inputs[0];
|
||||||
|
Shape output_shape(expression_.size());
|
||||||
|
int dim_to_infer = -1;
|
||||||
|
for (int i = 0, j = 0; i < expression_.size(); ++i) {
|
||||||
|
auto& e = expression_[i];
|
||||||
|
if (auto pv = std::get_if<int>(&e); pv) {
|
||||||
|
if (*pv == -1) {
|
||||||
|
dim_to_infer = i;
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
output_shape[i] = *pv;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto& s = std::get<std::string>(e);
|
||||||
|
output_shape[i] = in.shape()[j++];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dim_to_infer >= 0) {
|
||||||
|
uint64_t output_size = 1;
|
||||||
|
for (int i = 0; i < output_shape.size(); ++i) {
|
||||||
|
if (i != dim_to_infer) {
|
||||||
|
output_size *= output_shape[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output_shape[dim_to_infer] = in.size() / output_size;
|
||||||
|
}
|
||||||
|
return {std::move(output_shape)};
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<array> Reduce::vjp(
|
std::vector<array> Reduce::vjp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const std::vector<array>& cotangents,
|
const std::vector<array>& cotangents,
|
||||||
|
@ -1609,6 +1609,11 @@ class Reshape : public UnaryPrimitive {
|
|||||||
explicit Reshape(Stream stream, const Shape& shape)
|
explicit Reshape(Stream stream, const Shape& shape)
|
||||||
: UnaryPrimitive(stream), shape_(shape) {}
|
: UnaryPrimitive(stream), shape_(shape) {}
|
||||||
|
|
||||||
|
explicit Reshape(
|
||||||
|
Stream stream,
|
||||||
|
std::vector<std::variant<int, std::string>> expression)
|
||||||
|
: UnaryPrimitive(stream), expression_(std::move(expression)) {}
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
|
||||||
@ -1616,9 +1621,11 @@ class Reshape : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Reshape)
|
DEFINE_PRINT(Reshape)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Shape shape_;
|
Shape shape_;
|
||||||
|
std::vector<std::variant<int, std::string>> expression_;
|
||||||
|
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
|
|
||||||
|
@ -4880,4 +4880,27 @@ void init_ops(nb::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
array: The imaginary part of ``a``.
|
array: The imaginary part of ``a``.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"dynamic_reshape",
|
||||||
|
&dynamic_reshape,
|
||||||
|
nb::arg(),
|
||||||
|
"expression"_a,
|
||||||
|
nb::kw_only(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def dynamic_reshape(a: array, /, expression: Sequence[Union[int, str]], *, stream: "
|
||||||
|
"Union[None, Stream, Device] = None) -> array"),
|
||||||
|
R"pbdoc(
|
||||||
|
Dynamically reshape an array based on the given expression.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array.
|
||||||
|
expression (tuple(int or str)): The expression which determines the
|
||||||
|
output shape.
|
||||||
|
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||||
|
in which case the default stream of the default device is used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The reshaped array.
|
||||||
|
)pbdoc");
|
||||||
}
|
}
|
||||||
|
@ -2713,6 +2713,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(mx.imag(z).dtype, mx.float32)
|
self.assertEqual(mx.imag(z).dtype, mx.float32)
|
||||||
self.assertTrue(mx.array_equal(mx.imag(z), y))
|
self.assertTrue(mx.array_equal(mx.imag(z), y))
|
||||||
|
|
||||||
|
def test_dynamic_reshape(self):
|
||||||
|
a = mx.array(1)[None, None]
|
||||||
|
a = mx.dynamic_reshape(a, ())
|
||||||
|
self.assertEqual(a.shape, ())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -3769,3 +3769,23 @@ TEST_CASE("test contiguous") {
|
|||||||
CHECK(x.flags().col_contiguous);
|
CHECK(x.flags().col_contiguous);
|
||||||
CHECK_EQ(x.strides(), decltype(x.strides()){1, 2});
|
CHECK_EQ(x.strides(), decltype(x.strides()){1, 2});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test dynamic reshape") {
|
||||||
|
auto x = array({1}, {1, 1, 1});
|
||||||
|
CHECK_EQ(dynamic_reshape(x, {}).shape(), Shape{});
|
||||||
|
|
||||||
|
// Bad character
|
||||||
|
CHECK_THROWS(dynamic_reshape(x, {"&", 1, 1}));
|
||||||
|
|
||||||
|
// No dim in string
|
||||||
|
CHECK_THROWS(dynamic_reshape(x, {"1", 1, 1}));
|
||||||
|
|
||||||
|
// Too many dims
|
||||||
|
CHECK_THROWS(dynamic_reshape(x, {"abcd", 1, 1}));
|
||||||
|
|
||||||
|
// Too many dims
|
||||||
|
CHECK_THROWS(dynamic_reshape(x, {"a", "b", "c", "d"}));
|
||||||
|
|
||||||
|
// Too many inferred dims
|
||||||
|
CHECK_THROWS(dynamic_reshape(x, {"a", -1, -1}));
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user