mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-13 20:56:45 +08:00
enable more tests
This commit is contained in:
parent
85869fda0c
commit
14531cb14f
@ -107,6 +107,16 @@ same array:
|
|||||||
>>> a
|
>>> a
|
||||||
array([1, 2, 0], dtype=int32)
|
array([1, 2, 0], dtype=int32)
|
||||||
|
|
||||||
|
|
||||||
|
Note, unlike NumPy, updates to the same location are nondeterministic:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> a = mx.array([1, 2, 3])
|
||||||
|
>>> a[[0, 0]] = mx.array([4, 5])
|
||||||
|
|
||||||
|
The first element of ``a`` could be ``4`` or ``5``.
|
||||||
|
|
||||||
Transformations of functions which use in-place updates are allowed and work as
|
Transformations of functions which use in-place updates are allowed and work as
|
||||||
expected. For example:
|
expected. For example:
|
||||||
|
|
||||||
|
@ -5,6 +5,8 @@
|
|||||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||||
#include "mlx/backend/cuda/device/utils.cuh"
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
|
||||||
|
#include <math_constants.h>
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
struct Abs {
|
struct Abs {
|
||||||
@ -183,22 +185,39 @@ struct Imag {
|
|||||||
struct Log {
|
struct Log {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
|
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||||
|
auto r = log(cuCrealf(Abs{}(x)));
|
||||||
|
auto i = atan2f(cuCimagf(x), cuCrealf(x));
|
||||||
|
return {r, i};
|
||||||
|
} else {
|
||||||
return log(x);
|
return log(x);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Log2 {
|
struct Log2 {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
|
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||||
|
auto y = Log{}(x);
|
||||||
|
return {cuCrealf(y) / CUDART_LN2_F, cuCimagf(y) / CUDART_LN2_F};
|
||||||
|
} else {
|
||||||
return log2(x);
|
return log2(x);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Log10 {
|
struct Log10 {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
|
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||||
|
auto y = Log{}(x);
|
||||||
|
return {cuCrealf(y) / CUDART_LNT_F, cuCimagf(y) / CUDART_LNT_F};
|
||||||
|
return y;
|
||||||
|
} else {
|
||||||
return log10(x);
|
return log10(x);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Log1p {
|
struct Log1p {
|
||||||
|
@ -102,6 +102,11 @@ inline constexpr bool is_floating_v =
|
|||||||
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
|
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
|
||||||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
|
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
|
||||||
|
|
||||||
|
// Type traits for detecting complex or real floating point numbers.
|
||||||
|
template <typename T>
|
||||||
|
inline constexpr bool is_inexact_v =
|
||||||
|
is_floating_v<T> || cuda::std::is_same_v<T, complex64_t>;
|
||||||
|
|
||||||
// Utility to copy data from vector to array in host.
|
// Utility to copy data from vector to array in host.
|
||||||
template <int NDIM = MAX_NDIM, typename T = int32_t>
|
template <int NDIM = MAX_NDIM, typename T = int32_t>
|
||||||
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
|
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
|
||||||
|
@ -28,11 +28,14 @@ constexpr bool supports_unary_op() {
|
|||||||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
|
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
|
||||||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
|
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
|
||||||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> ||
|
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> ||
|
||||||
std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
std::is_same_v<Op, Sigmoid> ||
|
||||||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Sigmoid> ||
|
|
||||||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
|
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
|
||||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||||
}
|
}
|
||||||
|
if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
||||||
|
std::is_same_v<Op, Log10>) {
|
||||||
|
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||||
|
}
|
||||||
if (std::is_same_v<Op, BitwiseInvert>) {
|
if (std::is_same_v<Op, BitwiseInvert>) {
|
||||||
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
|
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
|
||||||
!std::is_same_v<In, bool>;
|
!std::is_same_v<In, bool>;
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
cuda_skip = {
|
cuda_skip = {
|
||||||
"TestArray.test_api",
|
"TestArray.test_api",
|
||||||
"TestArray.test_setitem",
|
|
||||||
"TestAutograd.test_cumprod_grad",
|
"TestAutograd.test_cumprod_grad",
|
||||||
"TestAutograd.test_slice_grads",
|
"TestAutograd.test_slice_grads",
|
||||||
"TestAutograd.test_split_against_slice",
|
"TestAutograd.test_split_against_slice",
|
||||||
@ -51,7 +50,6 @@ cuda_skip = {
|
|||||||
"TestEinsum.test_opt_einsum_test_cases",
|
"TestEinsum.test_opt_einsum_test_cases",
|
||||||
"TestEval.test_multi_output_eval_during_transform",
|
"TestEval.test_multi_output_eval_during_transform",
|
||||||
"TestExportImport.test_export_conv",
|
"TestExportImport.test_export_conv",
|
||||||
"TestFast.test_rope_grad",
|
|
||||||
"TestFFT.test_fft",
|
"TestFFT.test_fft",
|
||||||
"TestFFT.test_fft_big_powers_of_two",
|
"TestFFT.test_fft_big_powers_of_two",
|
||||||
"TestFFT.test_fft_contiguity",
|
"TestFFT.test_fft_contiguity",
|
||||||
@ -89,9 +87,6 @@ cuda_skip = {
|
|||||||
"TestOps.test_argpartition",
|
"TestOps.test_argpartition",
|
||||||
"TestOps.test_array_equal",
|
"TestOps.test_array_equal",
|
||||||
"TestOps.test_as_strided",
|
"TestOps.test_as_strided",
|
||||||
"TestOps.test_atleast_1d",
|
|
||||||
"TestOps.test_atleast_2d",
|
|
||||||
"TestOps.test_atleast_3d",
|
|
||||||
"TestOps.test_binary_ops",
|
"TestOps.test_binary_ops",
|
||||||
"TestOps.test_bitwise_grad",
|
"TestOps.test_bitwise_grad",
|
||||||
"TestOps.test_complex_ops",
|
"TestOps.test_complex_ops",
|
||||||
@ -101,15 +96,11 @@ cuda_skip = {
|
|||||||
"TestOps.test_hadamard_grad_vmap",
|
"TestOps.test_hadamard_grad_vmap",
|
||||||
"TestOps.test_irregular_binary_ops",
|
"TestOps.test_irregular_binary_ops",
|
||||||
"TestOps.test_kron",
|
"TestOps.test_kron",
|
||||||
"TestOps.test_log",
|
|
||||||
"TestOps.test_log10",
|
|
||||||
"TestOps.test_log1p",
|
"TestOps.test_log1p",
|
||||||
"TestOps.test_log2",
|
|
||||||
"TestOps.test_logaddexp",
|
"TestOps.test_logaddexp",
|
||||||
"TestOps.test_logcumsumexp",
|
"TestOps.test_logcumsumexp",
|
||||||
"TestOps.test_partition",
|
"TestOps.test_partition",
|
||||||
"TestOps.test_scans",
|
"TestOps.test_scans",
|
||||||
"TestOps.test_slice_update_reversed",
|
|
||||||
"TestOps.test_softmax",
|
"TestOps.test_softmax",
|
||||||
"TestOps.test_sort",
|
"TestOps.test_sort",
|
||||||
"TestOps.test_tensordot",
|
"TestOps.test_tensordot",
|
||||||
|
@ -1187,7 +1187,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
|
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
|
||||||
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
|
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
|
||||||
check_slices(
|
check_slices(
|
||||||
np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 0, 1])
|
np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 2, 1])
|
||||||
)
|
)
|
||||||
|
|
||||||
# Multiple slices
|
# Multiple slices
|
||||||
|
@ -2586,17 +2586,6 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqualArray(result, mx.array(expected))
|
self.assertEqualArray(result, mx.array(expected))
|
||||||
|
|
||||||
def test_atleast_1d(self):
|
def test_atleast_1d(self):
|
||||||
def compare_nested_lists(x, y):
|
|
||||||
if isinstance(x, list) and isinstance(y, list):
|
|
||||||
if len(x) != len(y):
|
|
||||||
return False
|
|
||||||
for i in range(len(x)):
|
|
||||||
if not compare_nested_lists(x[i], y[i]):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return x == y
|
|
||||||
|
|
||||||
# Test 1D input
|
# Test 1D input
|
||||||
arrays = [
|
arrays = [
|
||||||
[1],
|
[1],
|
||||||
@ -2614,23 +2603,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
for i, array in enumerate(arrays):
|
for i, array in enumerate(arrays):
|
||||||
mx_res = mx.atleast_1d(mx.array(array))
|
mx_res = mx.atleast_1d(mx.array(array))
|
||||||
np_res = np.atleast_1d(np.array(array))
|
np_res = np.atleast_1d(np.array(array))
|
||||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
|
||||||
self.assertEqual(mx_res.shape, np_res.shape)
|
self.assertEqual(mx_res.shape, np_res.shape)
|
||||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
|
||||||
|
|
||||||
def test_atleast_2d(self):
|
def test_atleast_2d(self):
|
||||||
def compare_nested_lists(x, y):
|
|
||||||
if isinstance(x, list) and isinstance(y, list):
|
|
||||||
if len(x) != len(y):
|
|
||||||
return False
|
|
||||||
for i in range(len(x)):
|
|
||||||
if not compare_nested_lists(x[i], y[i]):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return x == y
|
|
||||||
|
|
||||||
# Test 1D input
|
# Test 1D input
|
||||||
arrays = [
|
arrays = [
|
||||||
[1],
|
[1],
|
||||||
@ -2648,23 +2625,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
for i, array in enumerate(arrays):
|
for i, array in enumerate(arrays):
|
||||||
mx_res = mx.atleast_2d(mx.array(array))
|
mx_res = mx.atleast_2d(mx.array(array))
|
||||||
np_res = np.atleast_2d(np.array(array))
|
np_res = np.atleast_2d(np.array(array))
|
||||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
|
||||||
self.assertEqual(mx_res.shape, np_res.shape)
|
self.assertEqual(mx_res.shape, np_res.shape)
|
||||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
|
||||||
|
|
||||||
def test_atleast_3d(self):
|
def test_atleast_3d(self):
|
||||||
def compare_nested_lists(x, y):
|
|
||||||
if isinstance(x, list) and isinstance(y, list):
|
|
||||||
if len(x) != len(y):
|
|
||||||
return False
|
|
||||||
for i in range(len(x)):
|
|
||||||
if not compare_nested_lists(x[i], y[i]):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return x == y
|
|
||||||
|
|
||||||
# Test 1D input
|
# Test 1D input
|
||||||
arrays = [
|
arrays = [
|
||||||
[1],
|
[1],
|
||||||
@ -2682,10 +2647,9 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
for i, array in enumerate(arrays):
|
for i, array in enumerate(arrays):
|
||||||
mx_res = mx.atleast_3d(mx.array(array))
|
mx_res = mx.atleast_3d(mx.array(array))
|
||||||
np_res = np.atleast_3d(np.array(array))
|
np_res = np.atleast_3d(np.array(array))
|
||||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
|
||||||
self.assertEqual(mx_res.shape, np_res.shape)
|
self.assertEqual(mx_res.shape, np_res.shape)
|
||||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
|
||||||
|
|
||||||
def test_issubdtype(self):
|
def test_issubdtype(self):
|
||||||
self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))
|
self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))
|
||||||
|
Loading…
Reference in New Issue
Block a user