mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 17:39:05 +08:00
Shapeless support for zeros/ones_like (#2726)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
* shapeless support for zeros/ones_like * Improvements * fix access after moved
This commit is contained in:
@@ -2826,6 +2826,32 @@ TEST_CASE("test stack") {
|
||||
stack({x, y}, 0), "All arrays must have the same shape and dtype");
|
||||
}
|
||||
|
||||
TEST_CASE("test full_like") {
|
||||
auto base_int = array({1, 2, 3}, {3}, int16);
|
||||
|
||||
auto from_array_with_dtype = full_like(base_int, array(7.5f), float16);
|
||||
auto expected_float16 = array({7.5, 7.5, 7.5}, {3}, float16);
|
||||
CHECK_EQ(from_array_with_dtype.dtype(), float16);
|
||||
CHECK(array_equal(from_array_with_dtype, expected_float16).item<bool>());
|
||||
|
||||
auto from_array_default_dtype = full_like(base_int, array(4.0f));
|
||||
auto expected_int16 = array({4, 4, 4}, {3}, int16);
|
||||
CHECK_EQ(from_array_default_dtype.dtype(), int16);
|
||||
CHECK(array_equal(from_array_default_dtype, expected_int16).item<bool>());
|
||||
|
||||
auto from_scalar_with_dtype = full_like(base_int, 3.25f, float32);
|
||||
auto expected_float32 = array({3.25f, 3.25f, 3.25f}, {3}, float32);
|
||||
CHECK_EQ(from_scalar_with_dtype.dtype(), float32);
|
||||
CHECK(array_equal(from_scalar_with_dtype, expected_float32).item<bool>());
|
||||
|
||||
auto base_float = array({1.0f, 2.0f}, {2}, float32);
|
||||
auto from_scalar_default_dtype = full_like(base_float, 2);
|
||||
auto expected_base_float = array({2.0f, 2.0f}, {2}, float32);
|
||||
CHECK_EQ(from_scalar_default_dtype.dtype(), float32);
|
||||
CHECK(
|
||||
array_equal(from_scalar_default_dtype, expected_base_float).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test eye") {
|
||||
auto eye_3 = eye(3);
|
||||
CHECK_EQ(eye_3.shape(), Shape{3, 3});
|
||||
|
||||
Reference in New Issue
Block a user