mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-14 00:39:06 +08:00
Shapeless support for zeros/ones_like (#2726)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
* shapeless support for zeros/ones_like * Improvements * fix access after moved
This commit is contained in:
32
mlx/ops.cpp
32
mlx/ops.cpp
@@ -280,16 +280,19 @@ array copy(array a, StreamOrDevice s /* = {} */) {
|
|||||||
{std::move(a)});
|
{std::move(a)});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array full_impl(array vals, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||||
|
return array(
|
||||||
|
vals.shape(),
|
||||||
|
dtype,
|
||||||
|
std::make_shared<Full>(to_stream(s)),
|
||||||
|
{astype(vals, dtype, s)});
|
||||||
|
}
|
||||||
|
|
||||||
array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s /* = {} */) {
|
array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||||
if (std::any_of(shape.begin(), shape.end(), [](auto i) { return i < 0; })) {
|
if (std::any_of(shape.begin(), shape.end(), [](auto i) { return i < 0; })) {
|
||||||
throw std::invalid_argument("[full] Negative dimensions not allowed.");
|
throw std::invalid_argument("[full] Negative dimensions not allowed.");
|
||||||
}
|
}
|
||||||
auto copied_shape = shape; // |shape| will be moved
|
return full_impl(broadcast_to(vals, std::move(shape), s), dtype, s);
|
||||||
return array(
|
|
||||||
std::move(copied_shape),
|
|
||||||
dtype,
|
|
||||||
std::make_shared<Full>(to_stream(s)),
|
|
||||||
{broadcast_to(astype(std::move(vals), dtype, s), std::move(shape), s)});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array full(Shape shape, array vals, StreamOrDevice s /* = {} */) {
|
array full(Shape shape, array vals, StreamOrDevice s /* = {} */) {
|
||||||
@@ -297,12 +300,25 @@ array full(Shape shape, array vals, StreamOrDevice s /* = {} */) {
|
|||||||
return full(std::move(shape), std::move(vals), dtype, to_stream(s));
|
return full(std::move(shape), std::move(vals), dtype, to_stream(s));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array full_like(
|
||||||
|
const array& a,
|
||||||
|
array vals,
|
||||||
|
Dtype dtype,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
auto inputs = broadcast_arrays({a, std::move(vals)}, s);
|
||||||
|
return full_impl(std::move(inputs[1]), dtype, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array full_like(const array& a, array vals, StreamOrDevice s /* = {} */) {
|
||||||
|
return full_like(a, std::move(vals), a.dtype(), to_stream(s));
|
||||||
|
}
|
||||||
|
|
||||||
array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) {
|
array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||||
return full(shape, array(0, dtype), to_stream(s));
|
return full(shape, array(0, dtype), to_stream(s));
|
||||||
}
|
}
|
||||||
|
|
||||||
array zeros_like(const array& a, StreamOrDevice s /* = {} */) {
|
array zeros_like(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
return zeros(a.shape(), a.dtype(), to_stream(s));
|
return full_like(a, 0, a.dtype(), to_stream(s));
|
||||||
}
|
}
|
||||||
|
|
||||||
array ones(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) {
|
array ones(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||||
@@ -310,7 +326,7 @@ array ones(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
array ones_like(const array& a, StreamOrDevice s /* = {} */) {
|
array ones_like(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
return ones(a.shape(), a.dtype(), to_stream(s));
|
return full_like(a, 1, a.dtype(), to_stream(s));
|
||||||
}
|
}
|
||||||
|
|
||||||
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} */) {
|
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||||
|
|||||||
11
mlx/ops.h
11
mlx/ops.h
@@ -69,6 +69,17 @@ array full(Shape shape, T val, StreamOrDevice s = {}) {
|
|||||||
return full(std::move(shape), array(val), to_stream(s));
|
return full(std::move(shape), array(val), to_stream(s));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array full_like(const array& a, array vals, Dtype dtype, StreamOrDevice s = {});
|
||||||
|
array full_like(const array& a, array vals, StreamOrDevice s = {});
|
||||||
|
template <typename T>
|
||||||
|
array full_like(const array& a, T val, Dtype dtype, StreamOrDevice s = {}) {
|
||||||
|
return full_like(a, array(val, dtype), dtype, to_stream(s));
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
array full_like(const array& a, T val, StreamOrDevice s = {}) {
|
||||||
|
return full_like(a, array(val, a.dtype()), to_stream(s));
|
||||||
|
}
|
||||||
|
|
||||||
/** Fill an array of the given shape with zeros. */
|
/** Fill an array of the given shape with zeros. */
|
||||||
array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
|
array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
|
||||||
inline array zeros(const Shape& shape, StreamOrDevice s = {}) {
|
inline array zeros(const Shape& shape, StreamOrDevice s = {}) {
|
||||||
|
|||||||
@@ -1146,6 +1146,7 @@ class Full : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_NAME(Full)
|
DEFINE_NAME(Full)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
};
|
};
|
||||||
|
|
||||||
class Gather : public UnaryPrimitive {
|
class Gather : public UnaryPrimitive {
|
||||||
|
|||||||
@@ -482,6 +482,28 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertEqual(mx.compile(fun, shapeless=True)(x).shape, (1, 32))
|
self.assertEqual(mx.compile(fun, shapeless=True)(x).shape, (1, 32))
|
||||||
|
|
||||||
|
def test_shapeless_compile_full_like(self):
|
||||||
|
x_shape = (1, 1, 32)
|
||||||
|
x = mx.zeros((x_shape))
|
||||||
|
|
||||||
|
def zeros_fun(x):
|
||||||
|
return mx.zeros_like(x)
|
||||||
|
|
||||||
|
def ones_fun(x):
|
||||||
|
return mx.ones_like(x)
|
||||||
|
|
||||||
|
compiled_zero_like = mx.compile(zeros_fun, shapeless=True)
|
||||||
|
compiled_ones_like = mx.compile(ones_fun, shapeless=True)
|
||||||
|
|
||||||
|
self.assertEqual(compiled_zero_like(x).shape, x_shape)
|
||||||
|
self.assertEqual(compiled_ones_like(x).shape, x_shape)
|
||||||
|
|
||||||
|
y_shape = (2, 2, 16)
|
||||||
|
y = mx.zeros(y_shape)
|
||||||
|
|
||||||
|
self.assertEqual(compiled_zero_like(y).shape, y_shape)
|
||||||
|
self.assertEqual(compiled_ones_like(y).shape, y_shape)
|
||||||
|
|
||||||
def test_compile_with_constant(self):
|
def test_compile_with_constant(self):
|
||||||
# Test float
|
# Test float
|
||||||
@partial(mx.compile)
|
@partial(mx.compile)
|
||||||
@@ -842,7 +864,6 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(mx.allclose(out, expected))
|
self.assertTrue(mx.allclose(out, expected))
|
||||||
|
|
||||||
def test_compile_many_outputs(self):
|
def test_compile_many_outputs(self):
|
||||||
|
|
||||||
@mx.compile
|
@mx.compile
|
||||||
def fun(arr):
|
def fun(arr):
|
||||||
arrs = [arr] * 64
|
arrs = [arr] * 64
|
||||||
|
|||||||
@@ -2826,6 +2826,32 @@ TEST_CASE("test stack") {
|
|||||||
stack({x, y}, 0), "All arrays must have the same shape and dtype");
|
stack({x, y}, 0), "All arrays must have the same shape and dtype");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test full_like") {
|
||||||
|
auto base_int = array({1, 2, 3}, {3}, int16);
|
||||||
|
|
||||||
|
auto from_array_with_dtype = full_like(base_int, array(7.5f), float16);
|
||||||
|
auto expected_float16 = array({7.5, 7.5, 7.5}, {3}, float16);
|
||||||
|
CHECK_EQ(from_array_with_dtype.dtype(), float16);
|
||||||
|
CHECK(array_equal(from_array_with_dtype, expected_float16).item<bool>());
|
||||||
|
|
||||||
|
auto from_array_default_dtype = full_like(base_int, array(4.0f));
|
||||||
|
auto expected_int16 = array({4, 4, 4}, {3}, int16);
|
||||||
|
CHECK_EQ(from_array_default_dtype.dtype(), int16);
|
||||||
|
CHECK(array_equal(from_array_default_dtype, expected_int16).item<bool>());
|
||||||
|
|
||||||
|
auto from_scalar_with_dtype = full_like(base_int, 3.25f, float32);
|
||||||
|
auto expected_float32 = array({3.25f, 3.25f, 3.25f}, {3}, float32);
|
||||||
|
CHECK_EQ(from_scalar_with_dtype.dtype(), float32);
|
||||||
|
CHECK(array_equal(from_scalar_with_dtype, expected_float32).item<bool>());
|
||||||
|
|
||||||
|
auto base_float = array({1.0f, 2.0f}, {2}, float32);
|
||||||
|
auto from_scalar_default_dtype = full_like(base_float, 2);
|
||||||
|
auto expected_base_float = array({2.0f, 2.0f}, {2}, float32);
|
||||||
|
CHECK_EQ(from_scalar_default_dtype.dtype(), float32);
|
||||||
|
CHECK(
|
||||||
|
array_equal(from_scalar_default_dtype, expected_base_float).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
TEST_CASE("test eye") {
|
TEST_CASE("test eye") {
|
||||||
auto eye_3 = eye(3);
|
auto eye_3 = eye(3);
|
||||||
CHECK_EQ(eye_3.shape(), Shape{3, 3});
|
CHECK_EQ(eye_3.shape(), Shape{3, 3});
|
||||||
|
|||||||
Reference in New Issue
Block a user