diff --git a/docs/src/usage/indexing.rst b/docs/src/usage/indexing.rst index c74e357fa..dcbc84c1b 100644 --- a/docs/src/usage/indexing.rst +++ b/docs/src/usage/indexing.rst @@ -107,6 +107,16 @@ same array: >>> a array([1, 2, 0], dtype=int32) + +Note, unlike NumPy, updates to the same location are nondeterministic: + +.. code-block:: shell + + >>> a = mx.array([1, 2, 3]) + >>> a[[0, 0]] = mx.array([4, 5]) + +The first element of ``a`` could be ``4`` or ``5``. + Transformations of functions which use in-place updates are allowed and work as expected. For example: diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index 47efc44d2..d4df06f18 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -165,7 +165,7 @@ void binary_op_gpu_inplace( a.data(), b.data(), out.data(), - out.data_size(), + out.size(), const_param(shape), const_param(a_strides), const_param(b_strides)); @@ -178,7 +178,7 @@ void binary_op_gpu_inplace( a.data(), b.data(), out.data(), - out.data_size(), + out.size(), const_param(shape), const_param(a_strides), const_param(b_strides), @@ -196,8 +196,8 @@ void binary_op_gpu_inplace( } else if (bopt == BinaryOpType::VectorVector) { kernel = cu::binary_vv; } - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, LARGE); + auto [num_blocks, block_dims] = get_launch_args( + kernel, out.data_size(), out.shape(), out.strides(), LARGE); kernel<<>>( a.data(), b.data(), @@ -264,7 +264,6 @@ BINARY_GPU(Add) BINARY_GPU(ArcTan2) BINARY_GPU(Divide) BINARY_GPU(Remainder) -BINARY_GPU(Equal) BINARY_GPU(Greater) BINARY_GPU(GreaterEqual) BINARY_GPU(Less) @@ -279,6 +278,17 @@ BINARY_GPU(NotEqual) BINARY_GPU(Power) BINARY_GPU(Subtract) +void Equal::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Equal::eval_gpu"); + auto& s = out.primitive().stream(); + auto op = get_primitive_string(this); + if (equal_nan_) { + binary_op_gpu(inputs, out, op, s); + } else { + binary_op_gpu(inputs, out, op, s); + } +} + void BitwiseBinary::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("BitwiseBinary::eval_gpu"); auto& s = out.primitive().stream(); diff --git a/mlx/backend/cuda/copy.cu b/mlx/backend/cuda/copy.cu index 8649e1bf9..817860d0a 100644 --- a/mlx/backend/cuda/copy.cu +++ b/mlx/backend/cuda/copy.cu @@ -6,7 +6,7 @@ namespace mlx::core { void copy_gpu_inplace( - const array& in_, + const array& in, array& out, const Shape& shape, const Strides& strides_in, @@ -20,7 +20,6 @@ void copy_gpu_inplace( if (out.size() == 0) { return; } - const array& in = in_.data_shared_ptr() ? in_ : out; auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(in); diff --git a/mlx/backend/cuda/copy/copy.cuh b/mlx/backend/cuda/copy/copy.cuh index 0c1eff774..789826507 100644 --- a/mlx/backend/cuda/copy/copy.cuh +++ b/mlx/backend/cuda/copy/copy.cuh @@ -10,20 +10,13 @@ namespace mlx::core { -#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \ - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \ - MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \ - using InType = cuda_type_t; \ - using OutType = cuda_type_t; \ - if constexpr (cu::CastOp::is_castable) { \ - __VA_ARGS__; \ - } else { \ - throw std::runtime_error(fmt::format( \ - "Can not copy data from dtype {} to {}.", \ - dtype_to_string(out.dtype()), \ - dtype_to_string(in.dtype()))); \ - } \ - }); \ +#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \ + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \ + MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \ + using InType = cuda_type_t; \ + using OutType = cuda_type_t; \ + __VA_ARGS__; \ + }); \ }) void copy_contiguous( diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu index fa79f0604..5f4c9ca8f 100644 --- a/mlx/backend/cuda/copy/copy_contiguous.cu +++ b/mlx/backend/cuda/copy/copy_contiguous.cu @@ -43,7 +43,8 @@ void copy_contiguous( if (ctype == CopyType::Vector) { kernel = cu::copy_v; } - auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE); + auto [num_blocks, block_dims] = get_launch_args( + kernel, out.data_size(), out.shape(), out.strides(), LARGE); kernel<<>>( in.data() + in_offset, out.data() + out_offset, diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu index 3c5b3bbb3..9f50c8a31 100644 --- a/mlx/backend/cuda/copy/copy_general.cu +++ b/mlx/backend/cuda/copy/copy_general.cu @@ -59,9 +59,9 @@ void copy_general( MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { const InType* in_ptr = in.data() + offset_in; OutType* out_ptr = out.data() + offset_out; - bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX; + bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; + using IdxT = std::conditional_t; int ndim = shape.size(); if (ndim <= 3) { MLX_SWITCH_1_2_3(ndim, NDIM, { @@ -70,7 +70,7 @@ void copy_general( kernel<<>>( in_ptr, out_ptr, - out.data_size(), + out.size(), const_param(shape), const_param(strides_in), const_param(strides_out)); @@ -81,7 +81,7 @@ void copy_general( kernel<<>>( in_ptr, out_ptr, - out.data_size(), + out.size(), const_param(shape), const_param(strides_in), const_param(strides_out), diff --git a/mlx/backend/cuda/copy/copy_general_dynamic.cu b/mlx/backend/cuda/copy/copy_general_dynamic.cu index b9774662a..2e1cf4fba 100644 --- a/mlx/backend/cuda/copy/copy_general_dynamic.cu +++ b/mlx/backend/cuda/copy/copy_general_dynamic.cu @@ -65,9 +65,9 @@ void copy_general_dynamic( MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { const InType* in_ptr = in.data() + offset_in; OutType* out_ptr = out.data() + offset_out; - bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX; + bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; + using IdxT = std::conditional_t; int ndim = shape.size(); if (ndim <= 3) { MLX_SWITCH_1_2_3(ndim, NDIM, { @@ -76,7 +76,7 @@ void copy_general_dynamic( kernel<<>>( in_ptr, out_ptr, - out.data_size(), + out.size(), const_param(shape), const_param(strides_in), const_param(strides_out), @@ -89,7 +89,7 @@ void copy_general_dynamic( kernel<<>>( in_ptr, out_ptr, - out.data_size(), + out.size(), const_param(shape), const_param(strides_in), const_param(strides_out), diff --git a/mlx/backend/cuda/copy/copy_general_input.cu b/mlx/backend/cuda/copy/copy_general_input.cu index 4f2784927..a3bb37e53 100644 --- a/mlx/backend/cuda/copy/copy_general_input.cu +++ b/mlx/backend/cuda/copy/copy_general_input.cu @@ -54,9 +54,9 @@ void copy_general_input( MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { const InType* in_ptr = in.data() + offset_in; OutType* out_ptr = out.data() + offset_out; - bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX; + bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; + using IdxT = std::conditional_t; int ndim = shape.size(); if (ndim <= 3) { MLX_SWITCH_1_2_3(ndim, NDIM, { @@ -65,7 +65,7 @@ void copy_general_input( kernel<<>>( in_ptr, out_ptr, - out.data_size(), + out.size(), const_param(shape), const_param(strides_in)); }); @@ -75,7 +75,7 @@ void copy_general_input( kernel<<>>( in_ptr, out_ptr, - out.data_size(), + out.size(), const_param(shape), const_param(strides_in), ndim); diff --git a/mlx/backend/cuda/device/cast_op.cuh b/mlx/backend/cuda/device/cast_op.cuh index 30b44d46f..f15270432 100644 --- a/mlx/backend/cuda/device/cast_op.cuh +++ b/mlx/backend/cuda/device/cast_op.cuh @@ -45,6 +45,18 @@ struct CastOp< } }; +template +struct CastOp< + SrcT, + DstT, + cuda::std::enable_if_t>> { + static constexpr bool is_castable = true; + + __device__ SrcT operator()(SrcT x) { + return x; + } +}; + // Return an iterator that cast the value to DstT using CastOp. template __host__ __device__ auto make_cast_iterator(Iterator it) { diff --git a/mlx/backend/cuda/device/unary_ops.cuh b/mlx/backend/cuda/device/unary_ops.cuh index af7c30e64..efa9133b1 100644 --- a/mlx/backend/cuda/device/unary_ops.cuh +++ b/mlx/backend/cuda/device/unary_ops.cuh @@ -5,6 +5,8 @@ #include "mlx/backend/cuda/device/fp16_math.cuh" #include "mlx/backend/cuda/device/utils.cuh" +#include + namespace mlx::core::cu { struct Abs { @@ -183,21 +185,38 @@ struct Imag { struct Log { template __device__ T operator()(T x) { - return log(x); + if constexpr (cuda::std::is_same_v) { + auto r = log(cuCrealf(Abs{}(x))); + auto i = atan2f(cuCimagf(x), cuCrealf(x)); + return {r, i}; + } else { + return log(x); + } } }; struct Log2 { template __device__ T operator()(T x) { - return log2(x); + if constexpr (cuda::std::is_same_v) { + auto y = Log{}(x); + return {cuCrealf(y) / CUDART_LN2_F, cuCimagf(y) / CUDART_LN2_F}; + } else { + return log2(x); + } } }; struct Log10 { template __device__ T operator()(T x) { - return log10(x); + if constexpr (cuda::std::is_same_v) { + auto y = Log{}(x); + return {cuCrealf(y) / CUDART_LNT_F, cuCimagf(y) / CUDART_LNT_F}; + return y; + } else { + return log10(x); + } } }; diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh index 84392a1ec..b1fe875bd 100644 --- a/mlx/backend/cuda/kernel_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -102,6 +102,11 @@ inline constexpr bool is_floating_v = cuda::std::is_same_v || cuda::std::is_same_v || cuda::std::is_same_v || cuda::std::is_same_v; +// Type traits for detecting complex or real floating point numbers. +template +inline constexpr bool is_inexact_v = + is_floating_v || cuda::std::is_same_v; + // Utility to copy data from vector to array in host. template inline cuda::std::array const_param(const std::vector& vec) { @@ -136,17 +141,19 @@ inline uint max_occupancy_block_dim(T kernel) { template inline std::tuple get_launch_args( T kernel, - const array& arr, + size_t size, + const Shape& shape, + const Strides& strides, bool large, int work_per_thread = 1) { - size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread); + size_t nthreads = cuda::ceil_div(size, work_per_thread); uint block_dim = max_occupancy_block_dim(kernel); if (block_dim > nthreads) { block_dim = nthreads; } dim3 num_blocks; if (large) { - num_blocks = get_2d_grid_dims(arr.shape(), arr.strides(), work_per_thread); + num_blocks = get_2d_grid_dims(shape, strides, work_per_thread); num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim); } else { num_blocks.x = cuda::ceil_div(nthreads, block_dim); @@ -154,4 +161,14 @@ inline std::tuple get_launch_args( return std::make_tuple(num_blocks, block_dim); } +template +inline std::tuple get_launch_args( + T kernel, + const array& arr, + bool large, + int work_per_thread = 1) { + return get_launch_args( + kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index bb79d4249..02e46afc1 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -116,7 +116,7 @@ void ternary_op_gpu_inplace( b.data(), c.data(), out.data(), - out.data_size(), + out.size(), const_param(shape), const_param(a_strides), const_param(b_strides), @@ -142,7 +142,8 @@ void ternary_op_gpu_inplace( MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { using IdxT = std::conditional_t; auto kernel = cu::ternary_v; - auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE); + auto [num_blocks, block_dims] = get_launch_args( + kernel, out.data_size(), out.shape(), out.strides(), LARGE); kernel<<>>( a.data(), b.data(), diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index f9d373455..d2fa96381 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -28,11 +28,14 @@ constexpr bool supports_unary_op() { std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) { + std::is_same_v || std::is_same_v || + std::is_same_v) { return std::is_same_v && is_floating_v; } + if (std::is_same_v || std::is_same_v || + std::is_same_v) { + return std::is_same_v && is_inexact_v; + } if (std::is_same_v) { return std::is_same_v && std::is_integral_v && !std::is_same_v; @@ -91,7 +94,7 @@ void unary_op_gpu_inplace( } else { auto [shape, strides] = collapse_contiguous_dims(in); auto [in_begin, in_end] = cu::make_general_iterators( - in_ptr, in.data_size(), shape, strides); + in_ptr, in.size(), shape, strides); thrust::transform(policy, in_begin, in_end, out_ptr, Op()); } } else { diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index cda396dcb..0072db192 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -1,6 +1,5 @@ cuda_skip = { "TestArray.test_api", - "TestArray.test_setitem", "TestAutograd.test_cumprod_grad", "TestAutograd.test_slice_grads", "TestAutograd.test_split_against_slice", @@ -51,7 +50,6 @@ cuda_skip = { "TestEinsum.test_opt_einsum_test_cases", "TestEval.test_multi_output_eval_during_transform", "TestExportImport.test_export_conv", - "TestFast.test_rope_grad", "TestFFT.test_fft", "TestFFT.test_fft_big_powers_of_two", "TestFFT.test_fft_contiguity", @@ -89,9 +87,6 @@ cuda_skip = { "TestOps.test_argpartition", "TestOps.test_array_equal", "TestOps.test_as_strided", - "TestOps.test_atleast_1d", - "TestOps.test_atleast_2d", - "TestOps.test_atleast_3d", "TestOps.test_binary_ops", "TestOps.test_bitwise_grad", "TestOps.test_complex_ops", @@ -100,22 +95,16 @@ cuda_skip = { "TestOps.test_hadamard", "TestOps.test_hadamard_grad_vmap", "TestOps.test_irregular_binary_ops", - "TestOps.test_isfinite", "TestOps.test_kron", - "TestOps.test_log", - "TestOps.test_log10", "TestOps.test_log1p", - "TestOps.test_log2", "TestOps.test_logaddexp", "TestOps.test_logcumsumexp", "TestOps.test_partition", "TestOps.test_scans", - "TestOps.test_slice_update_reversed", "TestOps.test_softmax", "TestOps.test_sort", "TestOps.test_tensordot", "TestOps.test_tile", - "TestOps.test_view", "TestQuantized.test_gather_matmul_grad", "TestQuantized.test_gather_qmm", "TestQuantized.test_gather_qmm_sorted", @@ -136,7 +125,6 @@ cuda_skip = { "TestReduce.test_expand_sums", "TestReduce.test_many_reduction_axes", "TestUpsample.test_torch_upsample", - "TestVmap.test_unary", "TestVmap.test_vmap_conv", "TestVmap.test_vmap_inverse", "TestVmap.test_vmap_svd", diff --git a/python/tests/test_array.py b/python/tests/test_array.py index c02b524b4..3ab41bef7 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1187,7 +1187,7 @@ class TestArray(mlx_tests.MLXTestCase): check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1])) check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1])) check_slices( - np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 0, 1]) + np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 2, 1]) ) # Multiple slices diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 02ada39b4..8521d8f80 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2586,17 +2586,6 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqualArray(result, mx.array(expected)) def test_atleast_1d(self): - def compare_nested_lists(x, y): - if isinstance(x, list) and isinstance(y, list): - if len(x) != len(y): - return False - for i in range(len(x)): - if not compare_nested_lists(x[i], y[i]): - return False - return True - else: - return x == y - # Test 1D input arrays = [ [1], @@ -2614,23 +2603,11 @@ class TestOps(mlx_tests.MLXTestCase): for i, array in enumerate(arrays): mx_res = mx.atleast_1d(mx.array(array)) np_res = np.atleast_1d(np.array(array)) - self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) self.assertEqual(mx_res.shape, np_res.shape) self.assertEqual(mx_res.ndim, np_res.ndim) - self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) + self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i])) def test_atleast_2d(self): - def compare_nested_lists(x, y): - if isinstance(x, list) and isinstance(y, list): - if len(x) != len(y): - return False - for i in range(len(x)): - if not compare_nested_lists(x[i], y[i]): - return False - return True - else: - return x == y - # Test 1D input arrays = [ [1], @@ -2648,23 +2625,11 @@ class TestOps(mlx_tests.MLXTestCase): for i, array in enumerate(arrays): mx_res = mx.atleast_2d(mx.array(array)) np_res = np.atleast_2d(np.array(array)) - self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) self.assertEqual(mx_res.shape, np_res.shape) self.assertEqual(mx_res.ndim, np_res.ndim) - self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) + self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i])) def test_atleast_3d(self): - def compare_nested_lists(x, y): - if isinstance(x, list) and isinstance(y, list): - if len(x) != len(y): - return False - for i in range(len(x)): - if not compare_nested_lists(x[i], y[i]): - return False - return True - else: - return x == y - # Test 1D input arrays = [ [1], @@ -2682,10 +2647,9 @@ class TestOps(mlx_tests.MLXTestCase): for i, array in enumerate(arrays): mx_res = mx.atleast_3d(mx.array(array)) np_res = np.atleast_3d(np.array(array)) - self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) self.assertEqual(mx_res.shape, np_res.shape) self.assertEqual(mx_res.ndim, np_res.ndim) - self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) + self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i])) def test_issubdtype(self): self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))