Compare commits

...

2 Commits

Author SHA1 Message Date
Awni Hannun
91817a165b format 2025-06-16 07:46:40 -07:00
Awni Hannun
14531cb14f enable more tests 2025-06-16 07:45:01 -07:00
12 changed files with 62 additions and 69 deletions

View File

@ -107,6 +107,16 @@ same array:
>>> a >>> a
array([1, 2, 0], dtype=int32) array([1, 2, 0], dtype=int32)
Note, unlike NumPy, updates to the same location are nondeterministic:
.. code-block:: shell
>>> a = mx.array([1, 2, 3])
>>> a[[0, 0]] = mx.array([4, 5])
The first element of ``a`` could be ``4`` or ``5``.
Transformations of functions which use in-place updates are allowed and work as Transformations of functions which use in-place updates are allowed and work as
expected. For example: expected. For example:

View File

@ -196,8 +196,8 @@ void binary_op_gpu_inplace(
} else if (bopt == BinaryOpType::VectorVector) { } else if (bopt == BinaryOpType::VectorVector) {
kernel = cu::binary_vv<Op, InType, OutType, IdxT>; kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
} }
auto [num_blocks, block_dims] = auto [num_blocks, block_dims] = get_launch_args(
get_launch_args(kernel, out.data_size(), out.shape(), out.strides(), LARGE); kernel, out.data_size(), out.shape(), out.strides(), LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>( kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(), a.data<InType>(),
b.data<InType>(), b.data<InType>(),

View File

@ -10,13 +10,13 @@
namespace mlx::core { namespace mlx::core {
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \ #define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \ MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \ MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
using InType = cuda_type_t<CTYPE_IN>; \ using InType = cuda_type_t<CTYPE_IN>; \
using OutType = cuda_type_t<CTYPE_OUT>; \ using OutType = cuda_type_t<CTYPE_OUT>; \
__VA_ARGS__; \ __VA_ARGS__; \
}); \ }); \
}) })
void copy_contiguous( void copy_contiguous(

View File

@ -43,8 +43,8 @@ void copy_contiguous(
if (ctype == CopyType::Vector) { if (ctype == CopyType::Vector) {
kernel = cu::copy_v<InType, OutType, IdxT>; kernel = cu::copy_v<InType, OutType, IdxT>;
} }
auto [num_blocks, block_dims] = auto [num_blocks, block_dims] = get_launch_args(
get_launch_args(kernel, out.data_size(), out.shape(), out.strides(), LARGE); kernel, out.data_size(), out.shape(), out.strides(), LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>( kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>() + in_offset, in.data<InType>() + in_offset,
out.data<OutType>() + out_offset, out.data<OutType>() + out_offset,

View File

@ -57,7 +57,6 @@ struct CastOp<
} }
}; };
// Return an iterator that casts the value to DstT using CastOp. // Return an iterator that casts the value to DstT using CastOp.
template <typename DstT, typename Iterator> template <typename DstT, typename Iterator>
__host__ __device__ auto make_cast_iterator(Iterator it) { __host__ __device__ auto make_cast_iterator(Iterator it) {

View File

@ -5,6 +5,8 @@
#include "mlx/backend/cuda/device/fp16_math.cuh" #include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/device/utils.cuh"
#include <math_constants.h>
namespace mlx::core::cu { namespace mlx::core::cu {
struct Abs { struct Abs {
@ -183,21 +185,38 @@ struct Imag {
// Element-wise natural logarithm functor.
struct Log {
  // Real types forward to log(x). For cuComplex the result is built from the
  // log of the real component of Abs{}(x) and the angle atan2f(imag, real).
  // NOTE(review): assumes Abs{}(x) stores the complex magnitude in its real
  // component -- confirm against the Abs functor.
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (!cuda::std::is_same_v<T, cuComplex>) {
      return log(x);
    } else {
      auto magnitude = Abs{}(x);
      auto real_part = log(cuCrealf(magnitude));
      auto imag_part = atan2f(cuCimagf(x), cuCrealf(x));
      return {real_part, imag_part};
    }
  }
};
// Element-wise base-2 logarithm functor.
struct Log2 {
  // Real types forward to log2(x). For cuComplex, both components of the
  // natural logarithm computed by Log are divided by CUDART_LN2_F.
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (!cuda::std::is_same_v<T, cuComplex>) {
      return log2(x);
    } else {
      auto ln_x = Log{}(x);
      auto real_part = cuCrealf(ln_x) / CUDART_LN2_F;
      auto imag_part = cuCimagf(ln_x) / CUDART_LN2_F;
      return {real_part, imag_part};
    }
  }
};
// Element-wise base-10 logarithm functor.
struct Log10 {
  // Real types forward to log10(x). For cuComplex, both components of the
  // natural logarithm computed by Log are divided by CUDART_LNT_F (ln(10) in
  // CUDA's math_constants.h), which changes the base from e to 10.
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      auto y = Log{}(x);
      // The unreachable `return y;` that used to follow this statement was
      // removed; it could never execute.
      return {cuCrealf(y) / CUDART_LNT_F, cuCimagf(y) / CUDART_LNT_F};
    } else {
      return log10(x);
    }
  }
};

View File

@ -102,6 +102,11 @@ inline constexpr bool is_floating_v =
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> || cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>; cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
// Type trait that is true when T is complex64_t or any type accepted by
// is_floating_v, i.e. a complex or real floating point number.
template <typename T>
inline constexpr bool is_inexact_v =
    cuda::std::is_same_v<T, complex64_t> || is_floating_v<T>;
// Utility to copy data from vector to array in host. // Utility to copy data from vector to array in host.
template <int NDIM = MAX_NDIM, typename T = int32_t> template <int NDIM = MAX_NDIM, typename T = int32_t>
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) { inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
@ -162,7 +167,8 @@ inline std::tuple<dim3, uint> get_launch_args(
const array& arr, const array& arr,
bool large, bool large,
int work_per_thread = 1) { int work_per_thread = 1) {
return get_launch_args(kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread); return get_launch_args(
kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -142,7 +142,8 @@ void ternary_op_gpu_inplace(
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>; using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
auto kernel = cu::ternary_v<Op, DType, IdxT>; auto kernel = cu::ternary_v<Op, DType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out.data_size(), out.shape(), out.strides(), LARGE); auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>( kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<bool>(), a.data<bool>(),
b.data<DType>(), b.data<DType>(),

View File

@ -28,11 +28,14 @@ constexpr bool supports_unary_op() {
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> || std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> || std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> || std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> ||
std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> || std::is_same_v<Op, Sigmoid> || std::is_same_v<Op, Sqrt> ||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Sigmoid> || std::is_same_v<Op, Rsqrt>) {
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
return std::is_same_v<In, Out> && is_floating_v<In>; return std::is_same_v<In, Out> && is_floating_v<In>;
} }
if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
std::is_same_v<Op, Log10>) {
return std::is_same_v<In, Out> && is_inexact_v<In>;
}
if (std::is_same_v<Op, BitwiseInvert>) { if (std::is_same_v<Op, BitwiseInvert>) {
return std::is_same_v<In, Out> && std::is_integral_v<In> && return std::is_same_v<In, Out> && std::is_integral_v<In> &&
!std::is_same_v<In, bool>; !std::is_same_v<In, bool>;

View File

@ -1,6 +1,5 @@
cuda_skip = { cuda_skip = {
"TestArray.test_api", "TestArray.test_api",
"TestArray.test_setitem",
"TestAutograd.test_cumprod_grad", "TestAutograd.test_cumprod_grad",
"TestAutograd.test_slice_grads", "TestAutograd.test_slice_grads",
"TestAutograd.test_split_against_slice", "TestAutograd.test_split_against_slice",
@ -51,7 +50,6 @@ cuda_skip = {
"TestEinsum.test_opt_einsum_test_cases", "TestEinsum.test_opt_einsum_test_cases",
"TestEval.test_multi_output_eval_during_transform", "TestEval.test_multi_output_eval_during_transform",
"TestExportImport.test_export_conv", "TestExportImport.test_export_conv",
"TestFast.test_rope_grad",
"TestFFT.test_fft", "TestFFT.test_fft",
"TestFFT.test_fft_big_powers_of_two", "TestFFT.test_fft_big_powers_of_two",
"TestFFT.test_fft_contiguity", "TestFFT.test_fft_contiguity",
@ -89,9 +87,6 @@ cuda_skip = {
"TestOps.test_argpartition", "TestOps.test_argpartition",
"TestOps.test_array_equal", "TestOps.test_array_equal",
"TestOps.test_as_strided", "TestOps.test_as_strided",
"TestOps.test_atleast_1d",
"TestOps.test_atleast_2d",
"TestOps.test_atleast_3d",
"TestOps.test_binary_ops", "TestOps.test_binary_ops",
"TestOps.test_bitwise_grad", "TestOps.test_bitwise_grad",
"TestOps.test_complex_ops", "TestOps.test_complex_ops",
@ -101,15 +96,11 @@ cuda_skip = {
"TestOps.test_hadamard_grad_vmap", "TestOps.test_hadamard_grad_vmap",
"TestOps.test_irregular_binary_ops", "TestOps.test_irregular_binary_ops",
"TestOps.test_kron", "TestOps.test_kron",
"TestOps.test_log",
"TestOps.test_log10",
"TestOps.test_log1p", "TestOps.test_log1p",
"TestOps.test_log2",
"TestOps.test_logaddexp", "TestOps.test_logaddexp",
"TestOps.test_logcumsumexp", "TestOps.test_logcumsumexp",
"TestOps.test_partition", "TestOps.test_partition",
"TestOps.test_scans", "TestOps.test_scans",
"TestOps.test_slice_update_reversed",
"TestOps.test_softmax", "TestOps.test_softmax",
"TestOps.test_sort", "TestOps.test_sort",
"TestOps.test_tensordot", "TestOps.test_tensordot",

View File

@ -1187,7 +1187,7 @@ class TestArray(mlx_tests.MLXTestCase):
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1])) check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1])) check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
check_slices( check_slices(
np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 0, 1]) np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 2, 1])
) )
# Multiple slices # Multiple slices

View File

@ -2586,17 +2586,6 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqualArray(result, mx.array(expected)) self.assertEqualArray(result, mx.array(expected))
def test_atleast_1d(self): def test_atleast_1d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):
if len(x) != len(y):
return False
for i in range(len(x)):
if not compare_nested_lists(x[i], y[i]):
return False
return True
else:
return x == y
# Test 1D input # Test 1D input
arrays = [ arrays = [
[1], [1],
@ -2614,23 +2603,11 @@ class TestOps(mlx_tests.MLXTestCase):
for i, array in enumerate(arrays): for i, array in enumerate(arrays):
mx_res = mx.atleast_1d(mx.array(array)) mx_res = mx.atleast_1d(mx.array(array))
np_res = np.atleast_1d(np.array(array)) np_res = np.atleast_1d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape) self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim) self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
def test_atleast_2d(self): def test_atleast_2d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):
if len(x) != len(y):
return False
for i in range(len(x)):
if not compare_nested_lists(x[i], y[i]):
return False
return True
else:
return x == y
# Test 1D input # Test 1D input
arrays = [ arrays = [
[1], [1],
@ -2648,23 +2625,11 @@ class TestOps(mlx_tests.MLXTestCase):
for i, array in enumerate(arrays): for i, array in enumerate(arrays):
mx_res = mx.atleast_2d(mx.array(array)) mx_res = mx.atleast_2d(mx.array(array))
np_res = np.atleast_2d(np.array(array)) np_res = np.atleast_2d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape) self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim) self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
def test_atleast_3d(self): def test_atleast_3d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):
if len(x) != len(y):
return False
for i in range(len(x)):
if not compare_nested_lists(x[i], y[i]):
return False
return True
else:
return x == y
# Test 1D input # Test 1D input
arrays = [ arrays = [
[1], [1],
@ -2682,10 +2647,9 @@ class TestOps(mlx_tests.MLXTestCase):
for i, array in enumerate(arrays): for i, array in enumerate(arrays):
mx_res = mx.atleast_3d(mx.array(array)) mx_res = mx.atleast_3d(mx.array(array))
np_res = np.atleast_3d(np.array(array)) np_res = np.atleast_3d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape) self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim) self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
def test_issubdtype(self): def test_issubdtype(self):
self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact)) self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))