diff --git a/docs/src/usage/indexing.rst b/docs/src/usage/indexing.rst
index c74e357fae..dcbc84c1b9 100644
--- a/docs/src/usage/indexing.rst
+++ b/docs/src/usage/indexing.rst
@@ -107,6 +107,16 @@ same array:
   >>> a
   array([1, 2, 0], dtype=int32)
 
+
+Note, unlike NumPy, updates to the same location are nondeterministic:
+
+.. code-block:: shell
+
+  >>> a = mx.array([1, 2, 3])
+  >>> a[[0, 0]] = mx.array([4, 5])
+
+The first element of ``a`` could be ``4`` or ``5``.
+
 Transformations of functions which use in-place updates are allowed and
 work as expected. For example:
 
diff --git a/mlx/backend/cuda/device/unary_ops.cuh b/mlx/backend/cuda/device/unary_ops.cuh
index af7c30e64b..efa9133b1b 100644
--- a/mlx/backend/cuda/device/unary_ops.cuh
+++ b/mlx/backend/cuda/device/unary_ops.cuh
@@ -5,6 +5,8 @@
 #include "mlx/backend/cuda/device/fp16_math.cuh"
 #include "mlx/backend/cuda/device/utils.cuh"
 
+#include <math_constants.h>
+
 namespace mlx::core::cu {
 
 struct Abs {
@@ -183,21 +185,37 @@ struct Imag {
 struct Log {
   template <typename T>
   __device__ T operator()(T x) {
-    return log(x);
+    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      auto r = log(cuCrealf(Abs{}(x)));
+      auto i = atan2f(cuCimagf(x), cuCrealf(x));
+      return {r, i};
+    } else {
+      return log(x);
+    }
   }
 };
 
 struct Log2 {
   template <typename T>
   __device__ T operator()(T x) {
-    return log2(x);
+    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      auto y = Log{}(x);
+      return {cuCrealf(y) / CUDART_LN2_F, cuCimagf(y) / CUDART_LN2_F};
+    } else {
+      return log2(x);
+    }
   }
 };
 
 struct Log10 {
   template <typename T>
   __device__ T operator()(T x) {
-    return log10(x);
+    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      auto y = Log{}(x);
+      return {cuCrealf(y) / CUDART_LNT_F, cuCimagf(y) / CUDART_LNT_F};
+    } else {
+      return log10(x);
+    }
   }
 };
 
diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh
index 05febaa5c9..59b48f886e 100644
--- a/mlx/backend/cuda/kernel_utils.cuh
+++ b/mlx/backend/cuda/kernel_utils.cuh
@@ -102,6 +102,11 @@ inline constexpr bool is_floating_v =
     cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
     cuda::std::is_same_v<T, __half> || cuda::std::is_same_v<T, __nv_bfloat16>;
 
+// Type traits for detecting complex or real floating point numbers.
+template <typename T>
+inline constexpr bool is_inexact_v =
+    is_floating_v<T> || cuda::std::is_same_v<T, cuComplex>;
+
 // Utility to copy data from vector to array in host.
 template <typename T, size_t NDIM>
 inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu
index f02898705d..1cff4665e9 100644
--- a/mlx/backend/cuda/unary.cu
+++ b/mlx/backend/cuda/unary.cu
@@ -28,11 +28,14 @@ constexpr bool supports_unary_op() {
       std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
       std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
       std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> ||
-      std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
-      std::is_same_v<Op, Log10> || std::is_same_v<Op, Sigmoid> ||
+      std::is_same_v<Op, Sigmoid> ||
       std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
     return std::is_same_v<In, Out> && is_floating_v<In>;
   }
+  if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
+      std::is_same_v<Op, Log10>) {
+    return std::is_same_v<In, Out> && is_inexact_v<In>;
+  }
   if (std::is_same_v<Op, BitwiseInvert>) {
     return std::is_same_v<In, Out> && std::is_integral_v<In> &&
         !std::is_same_v<In, bool>;
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py
index ae3a32afcd..0072db192b 100644
--- a/python/tests/cuda_skip.py
+++ b/python/tests/cuda_skip.py
@@ -1,6 +1,5 @@
 cuda_skip = {
     "TestArray.test_api",
-    "TestArray.test_setitem",
    "TestAutograd.test_cumprod_grad",
     "TestAutograd.test_slice_grads",
     "TestAutograd.test_split_against_slice",
@@ -51,7 +50,6 @@ cuda_skip = {
     "TestEinsum.test_opt_einsum_test_cases",
     "TestEval.test_multi_output_eval_during_transform",
     "TestExportImport.test_export_conv",
-    "TestFast.test_rope_grad",
     "TestFFT.test_fft",
     "TestFFT.test_fft_big_powers_of_two",
     "TestFFT.test_fft_contiguity",
@@ -89,9 +87,6 @@ cuda_skip = {
     "TestOps.test_argpartition",
     "TestOps.test_array_equal",
     "TestOps.test_as_strided",
-    "TestOps.test_atleast_1d",
-    "TestOps.test_atleast_2d",
-    "TestOps.test_atleast_3d",
     "TestOps.test_binary_ops",
     "TestOps.test_bitwise_grad",
     "TestOps.test_complex_ops",
@@ -101,15 +96,11 @@ cuda_skip = {
     "TestOps.test_hadamard_grad_vmap",
     "TestOps.test_irregular_binary_ops",
     "TestOps.test_kron",
-    "TestOps.test_log",
-    "TestOps.test_log10",
     "TestOps.test_log1p",
-    "TestOps.test_log2",
     "TestOps.test_logaddexp",
     "TestOps.test_logcumsumexp",
     "TestOps.test_partition",
     "TestOps.test_scans",
-    "TestOps.test_slice_update_reversed",
     "TestOps.test_softmax",
     "TestOps.test_sort",
     "TestOps.test_tensordot",
diff --git a/python/tests/test_array.py b/python/tests/test_array.py
index c02b524b40..3ab41bef74 100644
--- a/python/tests/test_array.py
+++ b/python/tests/test_array.py
@@ -1187,7 +1187,7 @@ class TestArray(mlx_tests.MLXTestCase):
         check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
         check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
         check_slices(
-            np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 0, 1])
+            np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 2, 1])
         )
 
         # Multiple slices
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 02ada39b4d..8521d8f807 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -2586,17 +2586,6 @@ class TestOps(mlx_tests.MLXTestCase):
         self.assertEqualArray(result, mx.array(expected))
 
     def test_atleast_1d(self):
-        def compare_nested_lists(x, y):
-            if isinstance(x, list) and isinstance(y, list):
-                if len(x) != len(y):
-                    return False
-                for i in range(len(x)):
-                    if not compare_nested_lists(x[i], y[i]):
-                        return False
-                return True
-            else:
-                return x == y
-
         # Test 1D input
         arrays = [
             [1],
@@ -2614,23 +2603,11 @@ class TestOps(mlx_tests.MLXTestCase):
         for i, array in enumerate(arrays):
             mx_res = mx.atleast_1d(mx.array(array))
             np_res = np.atleast_1d(np.array(array))
-            self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
             self.assertEqual(mx_res.shape, np_res.shape)
             self.assertEqual(mx_res.ndim, np_res.ndim)
-            self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
+            self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
 
     def test_atleast_2d(self):
-        def compare_nested_lists(x, y):
-            if isinstance(x, list) and isinstance(y, list):
-                if len(x) != len(y):
-                    return False
-                for i in range(len(x)):
-                    if not compare_nested_lists(x[i], y[i]):
-                        return False
-                return True
-            else:
-                return x == y
-
         # Test 1D input
         arrays = [
             [1],
@@ -2648,23 +2625,11 @@ class TestOps(mlx_tests.MLXTestCase):
         for i, array in enumerate(arrays):
             mx_res = mx.atleast_2d(mx.array(array))
             np_res = np.atleast_2d(np.array(array))
-            self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
             self.assertEqual(mx_res.shape, np_res.shape)
             self.assertEqual(mx_res.ndim, np_res.ndim)
-            self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
+            self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
 
     def test_atleast_3d(self):
-        def compare_nested_lists(x, y):
-            if isinstance(x, list) and isinstance(y, list):
-                if len(x) != len(y):
-                    return False
-                for i in range(len(x)):
-                    if not compare_nested_lists(x[i], y[i]):
-                        return False
-                return True
-            else:
-                return x == y
-
         # Test 1D input
         arrays = [
             [1],
@@ -2682,10 +2647,9 @@ class TestOps(mlx_tests.MLXTestCase):
         for i, array in enumerate(arrays):
             mx_res = mx.atleast_3d(mx.array(array))
             np_res = np.atleast_3d(np.array(array))
-            self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
             self.assertEqual(mx_res.shape, np_res.shape)
             self.assertEqual(mx_res.ndim, np_res.ndim)
-            self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
+            self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
 
     def test_issubdtype(self):
         self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))