mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
format
This commit is contained in:
parent
ff1f9ca5e8
commit
abcf62ee55
@ -1,7 +1,7 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
|
||||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||||
|
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||||
#include "mlx/backend/cuda/device/utils.cuh"
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
|
||||||
#include <cuComplex.h>
|
#include <cuComplex.h>
|
||||||
@ -128,7 +128,9 @@ struct LogAddExp {
|
|||||||
__device__ cuComplex operator()(cuComplex x, cuComplex y) {
|
__device__ cuComplex operator()(cuComplex x, cuComplex y) {
|
||||||
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
|
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
|
||||||
isnan(cuCimagf(y))) {
|
isnan(cuCimagf(y))) {
|
||||||
return {cuda::std::numeric_limits<float>::quiet_NaN(), cuda::std::numeric_limits<float>::quiet_NaN()};
|
return {
|
||||||
|
cuda::std::numeric_limits<float>::quiet_NaN(),
|
||||||
|
cuda::std::numeric_limits<float>::quiet_NaN()};
|
||||||
}
|
}
|
||||||
constexpr float inf = cuda::std::numeric_limits<float>::infinity();
|
constexpr float inf = cuda::std::numeric_limits<float>::infinity();
|
||||||
auto maxval = x > y ? x : y;
|
auto maxval = x > y ? x : y;
|
||||||
|
@ -27,9 +27,8 @@ constexpr bool supports_unary_op() {
|
|||||||
std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
|
std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
|
||||||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
|
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
|
||||||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
|
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
|
||||||
std::is_same_v<Op, Expm1> ||
|
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Sigmoid> ||
|
||||||
std::is_same_v<Op, Sigmoid> || std::is_same_v<Op, Sqrt> ||
|
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
|
||||||
std::is_same_v<Op, Rsqrt>) {
|
|
||||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||||
}
|
}
|
||||||
if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
||||||
|
@ -25,32 +25,25 @@ cuda_skip = {
|
|||||||
"TestReduce.test_expand_sums",
|
"TestReduce.test_expand_sums",
|
||||||
"TestReduce.test_many_reduction_axes",
|
"TestReduce.test_many_reduction_axes",
|
||||||
"TestUpsample.test_torch_upsample",
|
"TestUpsample.test_torch_upsample",
|
||||||
|
|
||||||
# DivMod NYI
|
# DivMod NYI
|
||||||
"TestOps.test_divmod",
|
"TestOps.test_divmod",
|
||||||
"TestEval.test_multi_output_eval_during_transform",
|
"TestEval.test_multi_output_eval_during_transform",
|
||||||
|
|
||||||
# Partition NYI
|
# Partition NYI
|
||||||
"TestAutograd.test_topk_grad",
|
"TestAutograd.test_topk_grad",
|
||||||
"TestOps.test_argpartition",
|
"TestOps.test_argpartition",
|
||||||
"TestOps.test_partition",
|
"TestOps.test_partition",
|
||||||
|
|
||||||
# Block masked matmul NYI
|
# Block masked matmul NYI
|
||||||
"TestBlas.test_block_masked_matmul",
|
"TestBlas.test_block_masked_matmul",
|
||||||
|
|
||||||
# Gather matmul NYI
|
# Gather matmul NYI
|
||||||
"TestBlas.test_gather_matmul",
|
"TestBlas.test_gather_matmul",
|
||||||
"TestBlas.test_gather_matmul_grad",
|
"TestBlas.test_gather_matmul_grad",
|
||||||
|
|
||||||
# Scan NYI
|
# Scan NYI
|
||||||
"TestAutograd.test_cumprod_grad",
|
"TestAutograd.test_cumprod_grad",
|
||||||
"TestOps.test_scans",
|
"TestOps.test_scans",
|
||||||
"TestOps.test_logcumsumexp",
|
"TestOps.test_logcumsumexp",
|
||||||
|
|
||||||
# Hadamard NYI
|
# Hadamard NYI
|
||||||
"TestOps.test_hadamard",
|
"TestOps.test_hadamard",
|
||||||
"TestOps.test_hadamard_grad_vmap",
|
"TestOps.test_hadamard_grad_vmap",
|
||||||
|
|
||||||
# Convolutions NYI
|
# Convolutions NYI
|
||||||
"TestConv.test_1d_conv_with_2d",
|
"TestConv.test_1d_conv_with_2d",
|
||||||
"TestConv.test_asymmetric_padding",
|
"TestConv.test_asymmetric_padding",
|
||||||
@ -82,7 +75,6 @@ cuda_skip = {
|
|||||||
"TestLayers.test_conv1d",
|
"TestLayers.test_conv1d",
|
||||||
"TestLayers.test_conv2d",
|
"TestLayers.test_conv2d",
|
||||||
"TestVmap.test_vmap_conv",
|
"TestVmap.test_vmap_conv",
|
||||||
|
|
||||||
# FFTs NYI
|
# FFTs NYI
|
||||||
"TestFFT.test_fft",
|
"TestFFT.test_fft",
|
||||||
"TestFFT.test_fft_big_powers_of_two",
|
"TestFFT.test_fft_big_powers_of_two",
|
||||||
@ -93,7 +85,6 @@ cuda_skip = {
|
|||||||
"TestFFT.test_fft_large_numbers",
|
"TestFFT.test_fft_large_numbers",
|
||||||
"TestFFT.test_fft_shared_mem",
|
"TestFFT.test_fft_shared_mem",
|
||||||
"TestFFT.test_fftn",
|
"TestFFT.test_fftn",
|
||||||
|
|
||||||
# Lapack ops NYI
|
# Lapack ops NYI
|
||||||
"TestLinalg.test_cholesky",
|
"TestLinalg.test_cholesky",
|
||||||
"TestLinalg.test_cholesky_inv",
|
"TestLinalg.test_cholesky_inv",
|
||||||
@ -109,7 +100,6 @@ cuda_skip = {
|
|||||||
"TestLinalg.test_svd_decomposition",
|
"TestLinalg.test_svd_decomposition",
|
||||||
"TestVmap.test_vmap_svd",
|
"TestVmap.test_vmap_svd",
|
||||||
"TestLinalg.test_tri_inverse",
|
"TestLinalg.test_tri_inverse",
|
||||||
|
|
||||||
# Quantization NYI
|
# Quantization NYI
|
||||||
"TestQuantized.test_gather_matmul_grad",
|
"TestQuantized.test_gather_matmul_grad",
|
||||||
"TestQuantized.test_gather_qmm",
|
"TestQuantized.test_gather_qmm",
|
||||||
|
Loading…
Reference in New Issue
Block a user