| 
									
										
										
										
											2023-11-30 11:12:53 -08:00
										 |  |  | // Copyright © 2023 Apple Inc.
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-29 10:42:59 -08:00
										 |  |  | #include "doctest/doctest.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "mlx/mlx.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | using namespace mlx::core; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | TEST_CASE("test arange") { | 
					
						
							| 
									
										
										
										
											2024-01-02 00:08:17 -05:00
										 |  |  |   // Check type is inferred correctly
 | 
					
						
							| 
									
										
										
										
											2023-11-29 10:42:59 -08:00
										 |  |  |   { | 
					
						
							|  |  |  |     auto x = arange(10); | 
					
						
							|  |  |  |     CHECK_EQ(x.dtype(), int32); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(10.0); | 
					
						
							|  |  |  |     CHECK_EQ(x.dtype(), float32); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(10, float32); | 
					
						
							|  |  |  |     CHECK_EQ(x.dtype(), float32); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-21 23:33:43 +01:00
										 |  |  |     x = arange(10, float16); | 
					
						
							|  |  |  |     CHECK_EQ(x.dtype(), float16); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(10, bfloat16); | 
					
						
							|  |  |  |     CHECK_EQ(x.dtype(), bfloat16); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-29 10:42:59 -08:00
										 |  |  |     x = arange(10.0, int32); | 
					
						
							|  |  |  |     CHECK_EQ(x.dtype(), int32); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0, 10); | 
					
						
							|  |  |  |     CHECK_EQ(x.dtype(), int32); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0.0, 10.0, int32); | 
					
						
							|  |  |  |     CHECK_EQ(x.dtype(), int32); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0.0, 10.0); | 
					
						
							|  |  |  |     CHECK_EQ(x.dtype(), float32); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0, 10, float32); | 
					
						
							|  |  |  |     CHECK_EQ(x.dtype(), float32); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0, 10, 0.1, float32); | 
					
						
							|  |  |  |     CHECK_EQ(x.dtype(), float32); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0.0, 10.0, 0.5, int32); | 
					
						
							|  |  |  |     CHECK_EQ(x.dtype(), int32); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(10.0, uint32); | 
					
						
							|  |  |  |     CHECK_EQ(x.dtype(), uint32); | 
					
						
							|  |  |  |     x = arange(0.0, 10.0, uint32); | 
					
						
							|  |  |  |     CHECK_EQ(x.dtype(), uint32); | 
					
						
							|  |  |  |     x = arange(0.0, 10.0, 0.5, uint32); | 
					
						
							|  |  |  |     CHECK_EQ(x.dtype(), uint32); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // arange unsupported for bool_
 | 
					
						
							|  |  |  |     CHECK_THROWS_AS(arange(10, bool_), std::invalid_argument); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Check correct sizes
 | 
					
						
							|  |  |  |   { | 
					
						
							|  |  |  |     auto x = arange(10); | 
					
						
							|  |  |  |     CHECK_EQ(x.size(), 10); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0.0, 10.0, 0.5); | 
					
						
							|  |  |  |     CHECK_EQ(x.size(), 20); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0.0, 10.0, 0.45); | 
					
						
							|  |  |  |     CHECK_EQ(x.size(), 23); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0, 10, 10); | 
					
						
							|  |  |  |     CHECK_EQ(x.size(), 1); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0, 10, 9); | 
					
						
							|  |  |  |     CHECK_EQ(x.size(), 2); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0, 10, 100); | 
					
						
							|  |  |  |     CHECK_EQ(x.size(), 1); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0, -10, 1); | 
					
						
							|  |  |  |     CHECK_EQ(x.size(), 0); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0, -10, -1); | 
					
						
							|  |  |  |     CHECK_EQ(x.size(), 10); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0, -10, -10); | 
					
						
							|  |  |  |     CHECK_EQ(x.size(), 1); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Check values
 | 
					
						
							|  |  |  |   { | 
					
						
							|  |  |  |     auto x = arange(0, 3); | 
					
						
							|  |  |  |     CHECK(array_equal(x, array({0, 1, 2})).item<bool>()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0, 3, 2); | 
					
						
							|  |  |  |     CHECK(array_equal(x, array({0, 2})).item<bool>()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0, 3, 3); | 
					
						
							|  |  |  |     CHECK(array_equal(x, array({0})).item<bool>()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0, -3, 1); | 
					
						
							|  |  |  |     CHECK(array_equal(x, array({})).item<bool>()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0, 3, -1); | 
					
						
							|  |  |  |     CHECK(array_equal(x, array({})).item<bool>()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0, -3, -1); | 
					
						
							|  |  |  |     CHECK(array_equal(x, array({0, -1, -2})).item<bool>()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0.0, 5.0, 0.5, int32); | 
					
						
							|  |  |  |     CHECK(array_equal(x, zeros({10})).item<bool>()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0.0, 5.0, 1.5, int32); | 
					
						
							|  |  |  |     CHECK(array_equal(x, array({0, 1, 2, 3})).item<bool>()); | 
					
						
							| 
									
										
										
										
											2023-12-21 23:33:43 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0.0, 5.0, 1.0, float16); | 
					
						
							|  |  |  |     CHECK(array_equal(x, array({0, 1, 2, 3, 4}, float16)).item<bool>()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0.0, 5.0, 1.0, bfloat16); | 
					
						
							|  |  |  |     CHECK(array_equal(x, array({0, 1, 2, 3, 4}, bfloat16)).item<bool>()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = arange(0.0, 5.0, 1.5, bfloat16); | 
					
						
							|  |  |  |     CHECK(array_equal(x, array({0., 1.5, 3., 4.5}, bfloat16)).item<bool>()); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:42:59 -08:00
										 |  |  |   } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | TEST_CASE("test astype") { | 
					
						
							|  |  |  |   // Check type conversions
 | 
					
						
							|  |  |  |   { | 
					
						
							|  |  |  |     auto x = array(1); | 
					
						
							|  |  |  |     auto y = astype(x, float32); | 
					
						
							|  |  |  |     CHECK_EQ(y.dtype(), float32); | 
					
						
							|  |  |  |     CHECK_EQ(y.item<float>(), 1.0f); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     y = astype(x, int32); | 
					
						
							|  |  |  |     CHECK_EQ(y.dtype(), int32); | 
					
						
							|  |  |  |     CHECK_EQ(y.item<int>(), 1); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = array(-3.0f); | 
					
						
							|  |  |  |     y = astype(x, int32); | 
					
						
							|  |  |  |     CHECK_EQ(y.dtype(), int32); | 
					
						
							|  |  |  |     CHECK_EQ(y.item<int>(), -3); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | TEST_CASE("test full") { | 
					
						
							| 
									
										
										
										
											2024-01-29 11:19:38 -08:00
										 |  |  |   // Check throws on bad shape
 | 
					
						
							|  |  |  |   { | 
					
						
							|  |  |  |     CHECK_THROWS(full({-5, 0}, 0)); | 
					
						
							|  |  |  |     CHECK_THROWS(full({0, -5}, 0)); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-29 10:42:59 -08:00
										 |  |  |   // Check full works for different types
 | 
					
						
							|  |  |  |   { | 
					
						
							|  |  |  |     auto x = full({}, 0); | 
					
						
							|  |  |  |     CHECK_EQ(x.dtype(), int32); | 
					
						
							|  |  |  |     CHECK_EQ(x.item<int>(), 0); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = full({}, 0.0); | 
					
						
							|  |  |  |     CHECK_EQ(x.dtype(), float32); | 
					
						
							|  |  |  |     CHECK_EQ(x.item<float>(), 0); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = full({}, false); | 
					
						
							|  |  |  |     CHECK_EQ(x.item<bool>(), false); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = full({}, 0, int32); | 
					
						
							|  |  |  |     CHECK_EQ(x.item<int>(), 0); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = full({}, 0, float32); | 
					
						
							|  |  |  |     CHECK_EQ(x.item<float>(), 0); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = full({1, 2}, 2, float32); | 
					
						
							|  |  |  |     CHECK(array_equal(x, array({2.0, 2.0}, {1, 2})).item<bool>()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = full({2, 1}, 2, float32); | 
					
						
							|  |  |  |     CHECK(array_equal(x, array({2.0, 2.0}, {2, 1})).item<bool>()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = full({2}, false); | 
					
						
							|  |  |  |     CHECK_EQ(x.dtype(), bool_); | 
					
						
							|  |  |  |     CHECK(array_equal(x, array({false, false})).item<bool>()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = full({2}, 1.0, bool_); | 
					
						
							|  |  |  |     CHECK_EQ(x.dtype(), bool_); | 
					
						
							|  |  |  |     CHECK(array_equal(x, array({true, true})).item<bool>()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = full({2}, 1.0, uint32); | 
					
						
							|  |  |  |     CHECK_EQ(x.dtype(), uint32); | 
					
						
							|  |  |  |     CHECK(array_equal(x, array({1, 1})).item<bool>()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     CHECK_THROWS_AS(full({2}, array({})), std::invalid_argument); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Check broadcasting works
 | 
					
						
							|  |  |  |   { | 
					
						
							|  |  |  |     auto x = full({2, 2}, array({3, 4}, {2, 1})); | 
					
						
							|  |  |  |     CHECK(array_equal(x, array({3, 3, 4, 4}, {2, 2})).item<bool>()); | 
					
						
							|  |  |  |     x = full({2, 2}, array({3, 4}, {1, 2})); | 
					
						
							|  |  |  |     CHECK(array_equal(x, array({3, 4, 3, 4}, {2, 2})).item<bool>()); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Check zeros and ones
 | 
					
						
							|  |  |  |   { | 
					
						
							|  |  |  |     auto x = zeros({2, 2}, float32); | 
					
						
							| 
									
										
										
										
											2024-12-09 11:09:02 -08:00
										 |  |  |     CHECK_EQ(x.shape(), Shape{2, 2}); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:42:59 -08:00
										 |  |  |     CHECK_EQ(x.ndim(), 2); | 
					
						
							|  |  |  |     CHECK_EQ(x.dtype(), float32); | 
					
						
							|  |  |  |     auto y = array({0.0, 0.0, 0.0, 0.0}, {2, 2}); | 
					
						
							|  |  |  |     CHECK(array_equal(x, y).item<bool>()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = ones({2, 2}, float32); | 
					
						
							| 
									
										
										
										
											2024-12-09 11:09:02 -08:00
										 |  |  |     CHECK_EQ(x.shape(), Shape{2, 2}); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:42:59 -08:00
										 |  |  |     CHECK_EQ(x.ndim(), 2); | 
					
						
							|  |  |  |     CHECK_EQ(x.dtype(), float32); | 
					
						
							|  |  |  |     y = array({1.0, 1.0, 1.0, 1.0}, {2, 2}); | 
					
						
							|  |  |  |     CHECK(array_equal(x, y).item<bool>()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = zeros({2, 2}, int32); | 
					
						
							|  |  |  |     y = zeros_like(x); | 
					
						
							|  |  |  |     CHECK_EQ(y.dtype(), int32); | 
					
						
							|  |  |  |     CHECK(array_equal(x, y).item<bool>()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = ones({2, 2}, int32); | 
					
						
							|  |  |  |     y = ones_like(x); | 
					
						
							|  |  |  |     CHECK_EQ(y.dtype(), int32); | 
					
						
							|  |  |  |     CHECK(array_equal(x, y).item<bool>()); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Works for empty shape and empty array
 | 
					
						
							|  |  |  |   { | 
					
						
							|  |  |  |     array x = ones({}, int32); | 
					
						
							| 
									
										
										
										
											2024-12-09 11:09:02 -08:00
										 |  |  |     CHECK_EQ(x.shape(), Shape{}); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:42:59 -08:00
										 |  |  |     CHECK_EQ(x.item<int>(), 1); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     x = full({0}, array({})); | 
					
						
							| 
									
										
										
										
											2024-12-09 11:09:02 -08:00
										 |  |  |     CHECK_EQ(x.shape(), Shape{0}); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:42:59 -08:00
										 |  |  |     CHECK_EQ(x.size(), 0); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     CHECK_THROWS_AS(full({}, array({})), std::invalid_argument); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | } |