mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Compare commits
2 Commits
85869fda0c
...
91817a165b
Author | SHA1 | Date | |
---|---|---|---|
![]() |
91817a165b | ||
![]() |
14531cb14f |
@ -107,6 +107,16 @@ same array:
|
||||
>>> a
|
||||
array([1, 2, 0], dtype=int32)
|
||||
|
||||
|
||||
Note, unlike NumPy, updates to the same location are nondeterministic:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.array([1, 2, 3])
|
||||
>>> a[[0, 0]] = mx.array([4, 5])
|
||||
|
||||
The first element of ``a`` could be ``4`` or ``5``.
|
||||
|
||||
Transformations of functions which use in-place updates are allowed and work as
|
||||
expected. For example:
|
||||
|
||||
|
@ -196,8 +196,8 @@ void binary_op_gpu_inplace(
|
||||
} else if (bopt == BinaryOpType::VectorVector) {
|
||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
|
@ -10,13 +10,13 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
|
||||
using InType = cuda_type_t<CTYPE_IN>; \
|
||||
using OutType = cuda_type_t<CTYPE_OUT>; \
|
||||
__VA_ARGS__; \
|
||||
}); \
|
||||
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
|
||||
using InType = cuda_type_t<CTYPE_IN>; \
|
||||
using OutType = cuda_type_t<CTYPE_OUT>; \
|
||||
__VA_ARGS__; \
|
||||
}); \
|
||||
})
|
||||
|
||||
void copy_contiguous(
|
||||
|
@ -43,8 +43,8 @@ void copy_contiguous(
|
||||
if (ctype == CopyType::Vector) {
|
||||
kernel = cu::copy_v<InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>() + in_offset,
|
||||
out.data<OutType>() + out_offset,
|
||||
|
@ -57,7 +57,6 @@ struct CastOp<
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Return an iterator that cast the value to DstT using CastOp.
|
||||
template <typename DstT, typename Iterator>
|
||||
__host__ __device__ auto make_cast_iterator(Iterator it) {
|
||||
|
@ -5,6 +5,8 @@
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <math_constants.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
struct Abs {
|
||||
@ -183,21 +185,38 @@ struct Imag {
|
||||
struct Log {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log(x);
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
auto r = log(cuCrealf(Abs{}(x)));
|
||||
auto i = atan2f(cuCimagf(x), cuCrealf(x));
|
||||
return {r, i};
|
||||
} else {
|
||||
return log(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Log2 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log2(x);
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
auto y = Log{}(x);
|
||||
return {cuCrealf(y) / CUDART_LN2_F, cuCimagf(y) / CUDART_LN2_F};
|
||||
} else {
|
||||
return log2(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Log10 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log10(x);
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
auto y = Log{}(x);
|
||||
return {cuCrealf(y) / CUDART_LNT_F, cuCimagf(y) / CUDART_LNT_F};
|
||||
return y;
|
||||
} else {
|
||||
return log10(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -102,6 +102,11 @@ inline constexpr bool is_floating_v =
|
||||
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
|
||||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
|
||||
|
||||
// Type traits for detecting complex or real floating point numbers.
|
||||
template <typename T>
|
||||
inline constexpr bool is_inexact_v =
|
||||
is_floating_v<T> || cuda::std::is_same_v<T, complex64_t>;
|
||||
|
||||
// Utility to copy data from vector to array in host.
|
||||
template <int NDIM = MAX_NDIM, typename T = int32_t>
|
||||
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
|
||||
@ -162,7 +167,8 @@ inline std::tuple<dim3, uint> get_launch_args(
|
||||
const array& arr,
|
||||
bool large,
|
||||
int work_per_thread = 1) {
|
||||
return get_launch_args(kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
|
||||
return get_launch_args(
|
||||
kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -142,7 +142,8 @@ void ternary_op_gpu_inplace(
|
||||
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
auto kernel = cu::ternary_v<Op, DType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
|
@ -28,11 +28,14 @@ constexpr bool supports_unary_op() {
|
||||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
|
||||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
|
||||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> ||
|
||||
std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
||||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Sigmoid> ||
|
||||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
|
||||
std::is_same_v<Op, Sigmoid> || std::is_same_v<Op, Sqrt> ||
|
||||
std::is_same_v<Op, Rsqrt>) {
|
||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
||||
std::is_same_v<Op, Log10>) {
|
||||
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, BitwiseInvert>) {
|
||||
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
|
||||
!std::is_same_v<In, bool>;
|
||||
|
@ -1,6 +1,5 @@
|
||||
cuda_skip = {
|
||||
"TestArray.test_api",
|
||||
"TestArray.test_setitem",
|
||||
"TestAutograd.test_cumprod_grad",
|
||||
"TestAutograd.test_slice_grads",
|
||||
"TestAutograd.test_split_against_slice",
|
||||
@ -51,7 +50,6 @@ cuda_skip = {
|
||||
"TestEinsum.test_opt_einsum_test_cases",
|
||||
"TestEval.test_multi_output_eval_during_transform",
|
||||
"TestExportImport.test_export_conv",
|
||||
"TestFast.test_rope_grad",
|
||||
"TestFFT.test_fft",
|
||||
"TestFFT.test_fft_big_powers_of_two",
|
||||
"TestFFT.test_fft_contiguity",
|
||||
@ -89,9 +87,6 @@ cuda_skip = {
|
||||
"TestOps.test_argpartition",
|
||||
"TestOps.test_array_equal",
|
||||
"TestOps.test_as_strided",
|
||||
"TestOps.test_atleast_1d",
|
||||
"TestOps.test_atleast_2d",
|
||||
"TestOps.test_atleast_3d",
|
||||
"TestOps.test_binary_ops",
|
||||
"TestOps.test_bitwise_grad",
|
||||
"TestOps.test_complex_ops",
|
||||
@ -101,15 +96,11 @@ cuda_skip = {
|
||||
"TestOps.test_hadamard_grad_vmap",
|
||||
"TestOps.test_irregular_binary_ops",
|
||||
"TestOps.test_kron",
|
||||
"TestOps.test_log",
|
||||
"TestOps.test_log10",
|
||||
"TestOps.test_log1p",
|
||||
"TestOps.test_log2",
|
||||
"TestOps.test_logaddexp",
|
||||
"TestOps.test_logcumsumexp",
|
||||
"TestOps.test_partition",
|
||||
"TestOps.test_scans",
|
||||
"TestOps.test_slice_update_reversed",
|
||||
"TestOps.test_softmax",
|
||||
"TestOps.test_sort",
|
||||
"TestOps.test_tensordot",
|
||||
|
@ -1187,7 +1187,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
|
||||
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
|
||||
check_slices(
|
||||
np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 0, 1])
|
||||
np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 2, 1])
|
||||
)
|
||||
|
||||
# Multiple slices
|
||||
|
@ -2586,17 +2586,6 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqualArray(result, mx.array(expected))
|
||||
|
||||
def test_atleast_1d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
if isinstance(x, list) and isinstance(y, list):
|
||||
if len(x) != len(y):
|
||||
return False
|
||||
for i in range(len(x)):
|
||||
if not compare_nested_lists(x[i], y[i]):
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return x == y
|
||||
|
||||
# Test 1D input
|
||||
arrays = [
|
||||
[1],
|
||||
@ -2614,23 +2603,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
for i, array in enumerate(arrays):
|
||||
mx_res = mx.atleast_1d(mx.array(array))
|
||||
np_res = np.atleast_1d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
|
||||
|
||||
def test_atleast_2d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
if isinstance(x, list) and isinstance(y, list):
|
||||
if len(x) != len(y):
|
||||
return False
|
||||
for i in range(len(x)):
|
||||
if not compare_nested_lists(x[i], y[i]):
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return x == y
|
||||
|
||||
# Test 1D input
|
||||
arrays = [
|
||||
[1],
|
||||
@ -2648,23 +2625,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
for i, array in enumerate(arrays):
|
||||
mx_res = mx.atleast_2d(mx.array(array))
|
||||
np_res = np.atleast_2d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
|
||||
|
||||
def test_atleast_3d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
if isinstance(x, list) and isinstance(y, list):
|
||||
if len(x) != len(y):
|
||||
return False
|
||||
for i in range(len(x)):
|
||||
if not compare_nested_lists(x[i], y[i]):
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return x == y
|
||||
|
||||
# Test 1D input
|
||||
arrays = [
|
||||
[1],
|
||||
@ -2682,10 +2647,9 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
for i, array in enumerate(arrays):
|
||||
mx_res = mx.atleast_3d(mx.array(array))
|
||||
np_res = np.atleast_3d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
|
||||
|
||||
def test_issubdtype(self):
|
||||
self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))
|
||||
|
Loading…
Reference in New Issue
Block a user