mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-06 08:24:39 +08:00
Compare commits
3 Commits
cuda-sdpa-
...
dynamic_re
Author | SHA1 | Date | |
---|---|---|---|
![]() |
0c1155faf5 | ||
![]() |
2b9c24c517 | ||
![]() |
ee59d50293 |
@@ -80,7 +80,8 @@ bool allows_shapeless(const Primitive& p) {
|
||||
typeid(p) == typeid(Partition) || typeid(p) == typeid(Select) ||
|
||||
typeid(p) == typeid(NumberOfElements) || typeid(p) == typeid(Gather) ||
|
||||
typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) ||
|
||||
typeid(p) == typeid(Matmul) || typeid(p) == typeid(QuantizedMatmul) ||
|
||||
typeid(p) == typeid(Reshape) || typeid(p) == typeid(Matmul) ||
|
||||
typeid(p) == typeid(QuantizedMatmul) ||
|
||||
typeid(p) == typeid(fast::AffineQuantize) ||
|
||||
typeid(p) == typeid(fast::LayerNorm) ||
|
||||
typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) ||
|
||||
|
72
mlx/ops.cpp
72
mlx/ops.cpp
@@ -403,6 +403,78 @@ array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) {
|
||||
return array(std::move(shape), a.dtype(), std::move(p), {a});
|
||||
}
|
||||
|
||||
// Variant of string and int for the expressions
|
||||
array dynamic_reshape(
|
||||
const array& a,
|
||||
std::vector<std::variant<int, std::string>> expressions,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
// Reshape to scalar is not dynamic
|
||||
if (expressions.empty()) {
|
||||
return reshape(a, {}, s);
|
||||
}
|
||||
|
||||
// Validate expressions:
|
||||
// - At most one item in expressions is -1
|
||||
// - Any string expression should have a letter
|
||||
// - At most a.ndim() unique letters
|
||||
// - Only valid characters in string (alphabet, integer, *, /)
|
||||
bool infer_dim = false;
|
||||
std::unordered_map<char, int> char_to_dim;
|
||||
for (auto& e : expressions) {
|
||||
if (auto pv = std::get_if<int>(&e); pv) {
|
||||
if (*pv == -1) {
|
||||
if (infer_dim) {
|
||||
throw std::invalid_argument(
|
||||
"[dynamic_reshape] Cannot infer more than one dimension.");
|
||||
}
|
||||
infer_dim = true;
|
||||
}
|
||||
} else {
|
||||
auto& s = std::get<std::string>(e);
|
||||
bool has_alpha = false;
|
||||
for (auto c : s) {
|
||||
if (isalpha(c)) {
|
||||
has_alpha = true;
|
||||
char_to_dim.insert({c, char_to_dim.size()});
|
||||
} else if (!isdigit(c) && c != '*' && c != '/') {
|
||||
std::ostringstream msg;
|
||||
msg << "[dynamic_reshape] Invalid character in string expression \""
|
||||
<< s << "\".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
if (!has_alpha) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dynamic_reshape] String expression must contain at least "
|
||||
<< "one alphabetic character but got: \"" << s << "\".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (!isdigit(s[0]) && !isalpha(s[0]) && !isdigit(s.back()) &&
|
||||
!isalpha(s.back())) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dynamic_reshape] String expression must start and end with "
|
||||
<< "integer or letter but got: \"" << s << "\".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
if (char_to_dim.size() > a.ndim()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dynamic_reshape] Expressions contain " << char_to_dim.size()
|
||||
<< " abstract dimensions for array with only " << a.ndim()
|
||||
<< " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
auto output_shape =
|
||||
Reshape::shape_from_expressions(expressions, char_to_dim, a);
|
||||
return array(
|
||||
std::move(output_shape),
|
||||
a.dtype(),
|
||||
std::make_shared<Reshape>(
|
||||
to_stream(s), std::move(expressions), std::move(char_to_dim)),
|
||||
{a});
|
||||
}
|
||||
|
||||
array flatten(
|
||||
const array& a,
|
||||
int start_axis,
|
||||
|
@@ -117,6 +117,12 @@ array triu(array x, int k = 0, StreamOrDevice s = {});
|
||||
/** Reshape an array to the given shape. */
|
||||
array reshape(const array& a, Shape shape, StreamOrDevice s = {});
|
||||
|
||||
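/**
 * Dynamically reshape an array based on the given expressions.
 *
 * Each expression is either an integer (a literal dimension, or -1 at most
 * once to infer that dimension) or a string built from letters, digits, `*`
 * and `/` that contains at least one letter. Each distinct letter refers to a
 * dimension of the input, assigned in order of first appearance, e.g.
 * `{"a*b", "c"}` or `{"a", "b/10", 10}`. Malformed expressions throw
 * `std::invalid_argument`.
 */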
|
||||
array dynamic_reshape(
|
||||
const array& a,
|
||||
std::vector<std::variant<int, std::string>> expressions,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Flatten the dimensions in the range `[start_axis, end_axis]` . */
|
||||
array flatten(
|
||||
const array& a,
|
||||
|
@@ -2858,9 +2858,92 @@ std::vector<array> Reshape::jvp(
|
||||
|
||||
bool Reshape::is_equivalent(const Primitive& other) const {
|
||||
const Reshape& r_other = static_cast<const Reshape&>(other);
|
||||
if (!expressions_.empty()) {
|
||||
return expressions_ == r_other.expressions_;
|
||||
}
|
||||
return shape_ == r_other.shape_;
|
||||
}
|
||||
|
||||
// Resolve `expressions` into a concrete output shape for the input `in`.
//
// - Integer entries are used as-is; a single -1 entry is inferred so that the
//   output has the same number of elements as `in`.
// - String entries are evaluated left to right. Operands are decimal integers
//   or letters (looked up through `char_to_dim` as an axis of `in`), combined
//   with '*' and integer '/'.
//
// Throws std::invalid_argument when an expression is malformed, divides by
// zero, refers to an axis `in` does not have, yields a negative dimension,
// when a dimension must be inferred but the known sizes multiply to zero, or
// when the resulting shape does not match `in.size()`. A letter missing from
// `char_to_dim` throws std::out_of_range (from `unordered_map::at`).
Shape Reshape::shape_from_expressions(
    const std::vector<std::variant<int, std::string>>& expressions,
    const std::unordered_map<char, int>& char_to_dim,
    const array& in) {
  // Size of the input axis bound to `letter`, with a bounds check: the input
  // seen here may have fewer dimensions than the one used to build the map.
  auto dim_size = [&](char letter) -> int {
    const int axis = char_to_dim.at(letter);
    if (axis < 0 || static_cast<size_t>(axis) >= in.ndim()) {
      std::ostringstream msg;
      msg << "[dynamic_reshape] Expression refers to dimension " << axis
          << " but the input has only " << in.ndim() << " dimensions.";
      throw std::invalid_argument(msg.str());
    }
    return in.shape()[axis];
  };

  Shape output_shape(expressions.size());
  int dim_to_infer = -1;
  uint64_t size = 1;
  for (size_t i = 0; i < expressions.size(); ++i) {
    const auto& e = expressions[i];
    if (auto pv = std::get_if<int>(&e); pv) {
      if (*pv == -1) {
        // Filled in after all the other dimensions are known.
        dim_to_infer = static_cast<int>(i);
        continue;
      }
      output_shape[i] = *pv;
    } else {
      const auto& expr = std::get<std::string>(e);
      int d = 0;
      bool have_value = false;
      char op = 0; // last operator seen; 0 until the first one appears
      size_t loc = 0;
      while (loc < expr.size()) {
        const char c = expr[loc];
        const auto uc = static_cast<unsigned char>(c);
        int res = 0;
        if (std::isdigit(uc)) {
          // Consume the whole run of digits.
          char* end = nullptr;
          res = static_cast<int>(std::strtol(expr.c_str() + loc, &end, 10));
          loc = static_cast<size_t>(end - expr.c_str());
        } else if (std::isalpha(uc)) {
          res = dim_size(c);
          ++loc;
        } else if (c == '*' || c == '/') {
          op = c;
          ++loc;
          continue;
        } else {
          // Without this branch `loc` would never advance.
          std::ostringstream msg;
          msg << "[dynamic_reshape] Invalid character in string expression \""
              << expr << "\".";
          throw std::invalid_argument(msg.str());
        }

        if (op == 0) {
          d = res;
        } else if (!have_value) {
          // An operator appeared before any operand.
          std::ostringstream msg;
          msg << "[dynamic_reshape] String expression must start and end with "
              << "integer or letter but got: \"" << expr << "\".";
          throw std::invalid_argument(msg.str());
        } else if (op == '*') {
          d *= res;
        } else {
          if (res == 0) {
            std::ostringstream msg;
            msg << "[dynamic_reshape] Division by zero in string expression \""
                << expr << "\".";
            throw std::invalid_argument(msg.str());
          }
          d /= res;
        }
        have_value = true;
      }
      output_shape[i] = d;
    }

    if (output_shape[i] < 0) {
      // A negative value would wrap around in the unsigned product below.
      std::ostringstream msg;
      msg << "[dynamic_reshape] Invalid negative dimension " << output_shape[i]
          << " at position " << i << ".";
      throw std::invalid_argument(msg.str());
    }
    size *= output_shape[i];
  }

  if (dim_to_infer >= 0) {
    if (size == 0) {
      throw std::invalid_argument(
          "[dynamic_reshape] Cannot infer the shape of an empty array.");
    }
    auto d = in.size() / size;
    output_shape[dim_to_infer] = d;
    size *= d;
  }

  if (in.size() != size) {
    std::ostringstream msg;
    msg << "[dynamic_reshape] Cannot reshape array of size " << in.size()
        << " into shape " << output_shape << ".";
    throw std::invalid_argument(msg.str());
  }

  return output_shape;
}
|
||||
|
||||
std::vector<Shape> Reshape::output_shapes(const std::vector<array>& inputs) {
|
||||
// Only allowed to dynamically reshape when the shape is {}
|
||||
if (expressions_.empty() && !shape_.empty()) {
|
||||
throw std::invalid_argument(
|
||||
"[Reshape::output_shapes] Unable to infer output shape.");
|
||||
}
|
||||
return {shape_from_expressions(expressions_, char_to_dim_, inputs[0])};
|
||||
}
|
||||
|
||||
std::vector<array> Reduce::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
|
@@ -1609,6 +1609,14 @@ class Reshape : public UnaryPrimitive {
|
||||
explicit Reshape(Stream stream, const Shape& shape)
|
||||
: UnaryPrimitive(stream), shape_(shape) {}
|
||||
|
||||
explicit Reshape(
|
||||
Stream stream,
|
||||
std::vector<std::variant<int, std::string>> expressions,
|
||||
std::unordered_map<char, int> char_to_dim)
|
||||
: UnaryPrimitive(stream),
|
||||
expressions_(std::move(expressions)),
|
||||
char_to_dim_(std::move(char_to_dim)) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
@@ -1616,9 +1624,17 @@ class Reshape : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Reshape)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
|
||||
static Shape shape_from_expressions(
|
||||
const std::vector<std::variant<int, std::string>>& expressions,
|
||||
const std::unordered_map<char, int>& char_to_dim,
|
||||
const array& in);
|
||||
|
||||
private:
|
||||
Shape shape_;
|
||||
std::vector<std::variant<int, std::string>> expressions_;
|
||||
std::unordered_map<char, int> char_to_dim_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
|
||||
|
@@ -4880,4 +4880,27 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: The imaginary part of ``a``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"dynamic_reshape",
|
||||
&dynamic_reshape,
|
||||
nb::arg(),
|
||||
"expressions"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def dynamic_reshape(a: array, /, expressions: Sequence[Union[int, str]], *, stream: "
|
||||
"Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Dynamically reshape an array based on the given expression.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
expressions (tuple(int or str)): The expressions which determine
|
||||
the output shape.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||
in which case the default stream of the default device is used.
|
||||
|
||||
Returns:
|
||||
array: The reshaped array.
|
||||
)pbdoc");
|
||||
}
|
||||
|
@@ -809,6 +809,29 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
out = fun(*inputs)
|
||||
self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
|
||||
|
||||
def test_compile_shapeless_with_reshape(self):
    """Shapeless compile rejects a fixed reshape but accepts a dynamic one."""

    def static_reshape(x):
        return mx.reshape(x, (4, 7, 4, 2))

    def symbolic_reshape(x):
        return mx.dynamic_reshape(x, ("B", "L", 4, 2))

    a = mx.zeros((4, 7, 8))

    # A reshape to a fixed shape is expected to fail under shapeless compile.
    compiled_static = mx.compile(static_reshape, shapeless=True)
    with self.assertRaises(ValueError):
        b = compiled_static(a)

    # The dynamic variant follows the leading input dimensions.
    compiled_symbolic = mx.compile(symbolic_reshape, shapeless=True)
    b = compiled_symbolic(a)
    self.assertEqual(b.shape, (4, 7, 4, 2))

    # Same compiled function, different second dimension.
    a = mx.zeros((4, 9, 8))
    b = compiled_symbolic(a)
    self.assertEqual(b.shape, (4, 9, 4, 2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@@ -2713,6 +2713,21 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx.imag(z).dtype, mx.float32)
|
||||
self.assertTrue(mx.array_equal(mx.imag(z), y))
|
||||
|
||||
def test_dynamic_reshape(self):
    """Check output shapes of dynamic_reshape for a few expression tuples."""
    # An empty expression tuple reshapes a single-element array to a scalar.
    scalar = mx.dynamic_reshape(mx.array(1)[None, None], ())
    self.assertEqual(scalar.shape, ())

    a = mx.zeros((4, 4, 4))
    cases = [
        (("a", "b", "c"), (4, 4, 4)),
        (("a*b", "c"), (4 * 4, 4)),
        (("a*b*c", 1, 1), (4 * 4 * 4, 1, 1)),
    ]
    for expressions, expected_shape in cases:
        b = mx.dynamic_reshape(a, expressions)
        self.assertEqual(b.shape, expected_shape)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@@ -4,6 +4,7 @@
|
||||
#define _USE_MATH_DEFINES
|
||||
|
||||
#include <cmath>
|
||||
#include <iostream> // TODO
|
||||
#include <numeric>
|
||||
|
||||
#include "doctest/doctest.h"
|
||||
@@ -3769,3 +3770,61 @@ TEST_CASE("test contiguous") {
|
||||
CHECK(x.flags().col_contiguous);
|
||||
CHECK_EQ(x.strides(), decltype(x.strides()){1, 2});
|
||||
}
|
||||
|
||||
// Exercises dynamic_reshape: rejection of invalid expressions, size
// mismatches, empty arrays, and evaluation of symbolic expressions.
TEST_CASE("test dynamic reshape") {
  auto x = array({1}, {1, 1, 1});
  CHECK_EQ(dynamic_reshape(x, {}).shape(), Shape{});

  // Bad character
  CHECK_THROWS(dynamic_reshape(x, {"&", 1, 1}));

  // '+' is not a supported operator, at either end of the expression
  CHECK_THROWS(dynamic_reshape(x, {"+a", 1, 1}));
  CHECK_THROWS(dynamic_reshape(x, {"a+", 1, 1}));

  // No dim in string
  CHECK_THROWS(dynamic_reshape(x, {"1", 1, 1}));

  // Too many dims, within one expression or spread across several
  CHECK_THROWS(dynamic_reshape(x, {"abcd", 1, 1}));
  CHECK_THROWS(dynamic_reshape(x, {"a", "b", "c", "d"}));

  // Too many inferred dims
  CHECK_THROWS(dynamic_reshape(x, {"a", -1, -1}));

  // Bad sizes
  x = zeros({2, 2, 2});
  CHECK_THROWS_AS(dynamic_reshape(x, {7}), std::invalid_argument);
  CHECK_THROWS_AS(dynamic_reshape(x, {-1, 7}), std::invalid_argument);

  // Works with empty array
  x = array({});
  auto y = dynamic_reshape(x, {0, 0, 0});
  CHECK_EQ(y.shape(), Shape{0, 0, 0});
  CHECK_THROWS_AS(dynamic_reshape(x, {}), std::invalid_argument);
  CHECK_THROWS_AS(dynamic_reshape(x, {1}), std::invalid_argument);
  y = dynamic_reshape(x, {1, 5, 0});
  CHECK_EQ(y.shape(), Shape{1, 5, 0});

  // A single letter maps to the matching input dimension
  x = array({1, 2, 3});
  y = dynamic_reshape(x, {"a", 1, 1});
  CHECK_EQ(y.shape(), Shape{3, 1, 1});

  // Products of dimensions and of a constant with a dimension
  x = zeros({2, 2});
  y = dynamic_reshape(x, {"a*b"});
  CHECK_EQ(y.shape(), Shape{4});

  y = dynamic_reshape(x, {"2*a"});
  CHECK_EQ(y.shape(), Shape{4});

  // Multi-digit constants
  x = zeros({2, 20});
  y = dynamic_reshape(x, {"a*20"});
  CHECK_EQ(y.shape(), Shape{40});

  // Division of a dimension by a constant
  x = zeros({2, 20});
  y = dynamic_reshape(x, {"a", "b/10", 10});
  CHECK_EQ(y.shape(), Shape{2, 2, 10});
}
|
||||
|
Reference in New Issue
Block a user