mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Adds isinf (#445)
* adds isinf Signed-off-by: matthewfernst <matthew.f.ernst@gmail.com> * use stream + nits * typo --------- Signed-off-by: matthewfernst <matthew.f.ernst@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		| @@ -1134,12 +1134,6 @@ void init_array(py::module_& m) { | ||||
|           py::kw_only(), | ||||
|           "stream"_a = none, | ||||
|           "See :func:`any`.") | ||||
|       .def( | ||||
|           "isnan", | ||||
|           &mlx::core::isnan, | ||||
|           py::kw_only(), | ||||
|           "stream"_a = none, | ||||
|           "See :func:`isnan`.") | ||||
|       .def( | ||||
|           "moveaxis", | ||||
|           &moveaxis, | ||||
|   | ||||
| @@ -1836,7 +1836,25 @@ void init_ops(py::module_& m) { | ||||
|             a (array): Input array. | ||||
|  | ||||
|         Returns: | ||||
|             array: The array with boolean values indicating which elements are NaN. | ||||
|             array: The boolean array indicating which elements are NaN. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "isinf", | ||||
|       &mlx::core::isinf, | ||||
|       "a"_a, | ||||
|       py::pos_only(), | ||||
|       py::kw_only(), | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         isinf(a: array, stream: Union[None, Stream, Device] = None) -> array | ||||
|  | ||||
|         Return a boolean array indicating which elements are +/- inifnity. | ||||
|  | ||||
|         Args: | ||||
|             a (array): Input array. | ||||
|  | ||||
|         Returns: | ||||
|             array: The boolean array indicating which elements are +/- infinity. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "moveaxis", | ||||
|   | ||||
| @@ -336,6 +336,21 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|  | ||||
|         self.assertEqual(mx.isnan(0 * mx.array(float("inf"))).tolist(), True) | ||||
|  | ||||
|     def test_isinf(self): | ||||
|         x = mx.array([0.0, float("inf")]) | ||||
|         self.assertEqual(mx.isinf(x).tolist(), [False, True]) | ||||
|  | ||||
|         x = mx.array([0.0, float("inf")]).astype(mx.float16) | ||||
|         self.assertEqual(mx.isinf(x).tolist(), [False, True]) | ||||
|  | ||||
|         x = mx.array([0.0, float("inf")]).astype(mx.bfloat16) | ||||
|         self.assertEqual(mx.isinf(x).tolist(), [False, True]) | ||||
|  | ||||
|         x = mx.array([0.0, float("inf")]).astype(mx.complex64) | ||||
|         self.assertEqual(mx.isinf(x).tolist(), [False, True]) | ||||
|  | ||||
|         self.assertEqual(mx.isinf(0 * mx.array(float("inf"))).tolist(), False) | ||||
|  | ||||
|     def test_tri(self): | ||||
|         for shape in [[4], [4, 4], [2, 10]]: | ||||
|             for diag in [-1, 0, 1, -2]: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Matthew Ernst
					Matthew Ernst