mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
more bug fixes
This commit is contained in:
parent
c552ff2451
commit
7429613f76
@ -150,10 +150,10 @@ void binary_op_gpu_inplace(
|
|||||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||||
auto& a_strides = strides[0];
|
auto& a_strides = strides[0];
|
||||||
auto& b_strides = strides[1];
|
auto& b_strides = strides[1];
|
||||||
bool large = a.data_size() > UINT32_MAX ||
|
bool large = a.data_size() > INT32_MAX ||
|
||||||
b.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
b.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||||
MLX_SWITCH_BOOL(large, LARGE, {
|
MLX_SWITCH_BOOL(large, LARGE, {
|
||||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
if (ndim <= 3) {
|
if (ndim <= 3) {
|
||||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||||
|
@ -130,11 +130,13 @@ struct FusedKernelBuilder {
|
|||||||
|
|
||||||
constexpr const char* g_jit_includes = R"(
|
constexpr const char* g_jit_includes = R"(
|
||||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||||
|
#include "mlx/backend/cuda/device/ternary_ops.cuh"
|
||||||
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
||||||
#include "mlx/backend/cuda/device/utils.cuh"
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
|
||||||
#include <cooperative_groups.h>
|
#include <cooperative_groups.h>
|
||||||
|
|
||||||
|
#define inf cuda::std::numeric_limits<float>::infinity()
|
||||||
)";
|
)";
|
||||||
|
|
||||||
void Compiled::eval_gpu(
|
void Compiled::eval_gpu(
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
#pragma once
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
@ -336,4 +336,21 @@ struct LoopedElemToLoc<1, false, OffsetT> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
inline __device__ cuComplex log1p(cuComplex in) {
|
||||||
|
float x = cuCrealf(in);
|
||||||
|
float y = cuCimagf(in);
|
||||||
|
float zabs = sqrt(x * x + y * y);
|
||||||
|
float theta = atan2f(y, x + 1);
|
||||||
|
if (zabs < 0.5f) {
|
||||||
|
float r = x * (2 + x) + y * y;
|
||||||
|
if (r == 0) { // handle underflow
|
||||||
|
return {x, theta};
|
||||||
|
}
|
||||||
|
return {0.5f * log1pf(r), theta};
|
||||||
|
} else {
|
||||||
|
auto z0 = sqrt((x + 1) * (x + 1) + y * y);
|
||||||
|
return {log(z0), theta};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core::cu
|
||||||
|
@ -101,10 +101,10 @@ void ternary_op_gpu_inplace(
|
|||||||
auto& a_strides = strides[0];
|
auto& a_strides = strides[0];
|
||||||
auto& b_strides = strides[1];
|
auto& b_strides = strides[1];
|
||||||
auto& c_strides = strides[2];
|
auto& c_strides = strides[2];
|
||||||
bool large = a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
|
bool large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||||
c.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
c.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||||
MLX_SWITCH_BOOL(large, LARGE, {
|
MLX_SWITCH_BOOL(large, LARGE, {
|
||||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
if (ndim <= 3) {
|
if (ndim <= 3) {
|
||||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||||
|
@ -27,13 +27,13 @@ constexpr bool supports_unary_op() {
|
|||||||
std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
|
std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
|
||||||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
|
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
|
||||||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
|
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
|
||||||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> ||
|
std::is_same_v<Op, Expm1> ||
|
||||||
std::is_same_v<Op, Sigmoid> || std::is_same_v<Op, Sqrt> ||
|
std::is_same_v<Op, Sigmoid> || std::is_same_v<Op, Sqrt> ||
|
||||||
std::is_same_v<Op, Rsqrt>) {
|
std::is_same_v<Op, Rsqrt>) {
|
||||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||||
}
|
}
|
||||||
if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
||||||
std::is_same_v<Op, Log10>) {
|
std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p>) {
|
||||||
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||||
}
|
}
|
||||||
if (std::is_same_v<Op, BitwiseInvert>) {
|
if (std::is_same_v<Op, BitwiseInvert>) {
|
||||||
|
@ -1,24 +1,68 @@
|
|||||||
cuda_skip = {
|
cuda_skip = {
|
||||||
"TestArray.test_api",
|
"TestArray.test_api",
|
||||||
"TestAutograd.test_cumprod_grad",
|
|
||||||
"TestAutograd.test_slice_grads",
|
"TestAutograd.test_slice_grads",
|
||||||
"TestAutograd.test_split_against_slice",
|
"TestAutograd.test_split_against_slice",
|
||||||
"TestAutograd.test_stop_gradient",
|
"TestAutograd.test_stop_gradient",
|
||||||
"TestAutograd.test_topk_grad",
|
|
||||||
"TestAutograd.test_update_state",
|
"TestAutograd.test_update_state",
|
||||||
"TestAutograd.test_vjp",
|
"TestAutograd.test_vjp",
|
||||||
"TestBF16.test_arg_reduction_ops",
|
"TestBF16.test_arg_reduction_ops",
|
||||||
"TestBF16.test_binary_ops",
|
"TestBF16.test_binary_ops",
|
||||||
"TestBF16.test_reduction_ops",
|
"TestBF16.test_reduction_ops",
|
||||||
"TestBlas.test_block_masked_matmul",
|
|
||||||
"TestBlas.test_complex_gemm",
|
"TestBlas.test_complex_gemm",
|
||||||
"TestBlas.test_gather_matmul",
|
|
||||||
"TestBlas.test_gather_matmul_grad",
|
|
||||||
"TestBlas.test_matmul_batched",
|
"TestBlas.test_matmul_batched",
|
||||||
"TestBlas.test_matrix_vector_attn",
|
"TestBlas.test_matrix_vector_attn",
|
||||||
"TestCompile.test_compile_dynamic_dims",
|
"TestCompile.test_compile_dynamic_dims",
|
||||||
"TestCompile.test_compile_inf",
|
"TestEinsum.test_attention",
|
||||||
"TestCompile.test_inf_constant",
|
"TestEinsum.test_ellipses",
|
||||||
|
"TestEinsum.test_opt_einsum_test_cases",
|
||||||
|
"TestEval.test_multi_output_eval_during_transform",
|
||||||
|
"TestLoad.test_load_f8_e4m3",
|
||||||
|
"TestLosses.test_binary_cross_entropy",
|
||||||
|
"TestMemory.test_memory_info",
|
||||||
|
"TestLayers.test_group_norm",
|
||||||
|
"TestLayers.test_pooling",
|
||||||
|
"TestLayers.test_quantized_embedding",
|
||||||
|
"TestLayers.test_sin_pe",
|
||||||
|
"TestLayers.test_upsample",
|
||||||
|
"TestOps.test_array_equal",
|
||||||
|
"TestOps.test_complex_ops",
|
||||||
|
"TestOps.test_divmod",
|
||||||
|
"TestOps.test_dynamic_slicing",
|
||||||
|
"TestOps.test_irregular_binary_ops",
|
||||||
|
"TestOps.test_kron",
|
||||||
|
"TestOps.test_logaddexp",
|
||||||
|
"TestOps.test_softmax",
|
||||||
|
"TestOps.test_sort",
|
||||||
|
"TestOps.test_tensordot",
|
||||||
|
"TestOps.test_tile",
|
||||||
|
"TestReduce.test_axis_permutation_sums",
|
||||||
|
"TestReduce.test_dtypes",
|
||||||
|
"TestReduce.test_expand_sums",
|
||||||
|
"TestReduce.test_many_reduction_axes",
|
||||||
|
"TestUpsample.test_torch_upsample",
|
||||||
|
|
||||||
|
# Partition NYI
|
||||||
|
"TestAutograd.test_topk_grad",
|
||||||
|
"TestOps.test_argpartition",
|
||||||
|
"TestOps.test_partition",
|
||||||
|
|
||||||
|
# Block masked matmul NYI
|
||||||
|
"TestBlas.test_block_masked_matmul",
|
||||||
|
|
||||||
|
# Gather matmul NYI
|
||||||
|
"TestBlas.test_gather_matmul",
|
||||||
|
"TestBlas.test_gather_matmul_grad",
|
||||||
|
|
||||||
|
# Scan NYI
|
||||||
|
"TestAutograd.test_cumprod_grad",
|
||||||
|
"TestOps.test_scans",
|
||||||
|
"TestOps.test_logcumsumexp",
|
||||||
|
|
||||||
|
# Hadamard NYI
|
||||||
|
"TestOps.test_hadamard",
|
||||||
|
"TestOps.test_hadamard_grad_vmap",
|
||||||
|
|
||||||
|
# Convolutions NYI
|
||||||
"TestConv.test_1d_conv_with_2d",
|
"TestConv.test_1d_conv_with_2d",
|
||||||
"TestConv.test_asymmetric_padding",
|
"TestConv.test_asymmetric_padding",
|
||||||
"TestConv.test_basic_grad_shapes",
|
"TestConv.test_basic_grad_shapes",
|
||||||
@ -45,11 +89,12 @@ cuda_skip = {
|
|||||||
"TestConvTranspose.test_torch_conv_transpose_3D",
|
"TestConvTranspose.test_torch_conv_transpose_3D",
|
||||||
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
|
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
|
||||||
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
|
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
|
||||||
"TestEinsum.test_attention",
|
|
||||||
"TestEinsum.test_ellipses",
|
|
||||||
"TestEinsum.test_opt_einsum_test_cases",
|
|
||||||
"TestEval.test_multi_output_eval_during_transform",
|
|
||||||
"TestExportImport.test_export_conv",
|
"TestExportImport.test_export_conv",
|
||||||
|
"TestLayers.test_conv1d",
|
||||||
|
"TestLayers.test_conv2d",
|
||||||
|
"TestVmap.test_vmap_conv",
|
||||||
|
|
||||||
|
# FFTs NYI
|
||||||
"TestFFT.test_fft",
|
"TestFFT.test_fft",
|
||||||
"TestFFT.test_fft_big_powers_of_two",
|
"TestFFT.test_fft_big_powers_of_two",
|
||||||
"TestFFT.test_fft_contiguity",
|
"TestFFT.test_fft_contiguity",
|
||||||
@ -59,52 +104,24 @@ cuda_skip = {
|
|||||||
"TestFFT.test_fft_large_numbers",
|
"TestFFT.test_fft_large_numbers",
|
||||||
"TestFFT.test_fft_shared_mem",
|
"TestFFT.test_fft_shared_mem",
|
||||||
"TestFFT.test_fftn",
|
"TestFFT.test_fftn",
|
||||||
"TestInit.test_orthogonal",
|
|
||||||
|
# Lapack ops NYI
|
||||||
"TestLinalg.test_cholesky",
|
"TestLinalg.test_cholesky",
|
||||||
"TestLinalg.test_cholesky_inv",
|
"TestLinalg.test_cholesky_inv",
|
||||||
"TestLinalg.test_eig",
|
"TestLinalg.test_eig",
|
||||||
"TestLinalg.test_eigh",
|
"TestLinalg.test_eigh",
|
||||||
"TestLinalg.test_inverse",
|
"TestLinalg.test_inverse",
|
||||||
|
"TestVmap.test_vmap_inverse",
|
||||||
"TestLinalg.test_lu",
|
"TestLinalg.test_lu",
|
||||||
"TestLinalg.test_lu_factor",
|
"TestLinalg.test_lu_factor",
|
||||||
"TestLinalg.test_pseudo_inverse",
|
"TestLinalg.test_pseudo_inverse",
|
||||||
"TestLinalg.test_qr_factorization",
|
"TestLinalg.test_qr_factorization",
|
||||||
|
"TestInit.test_orthogonal",
|
||||||
"TestLinalg.test_svd_decomposition",
|
"TestLinalg.test_svd_decomposition",
|
||||||
|
"TestVmap.test_vmap_svd",
|
||||||
"TestLinalg.test_tri_inverse",
|
"TestLinalg.test_tri_inverse",
|
||||||
"TestLoad.test_load_f8_e4m3",
|
|
||||||
"TestLosses.test_binary_cross_entropy",
|
# Quantization NYI
|
||||||
"TestMemory.test_memory_info",
|
|
||||||
"TestLayers.test_conv1d",
|
|
||||||
"TestLayers.test_conv2d",
|
|
||||||
"TestLayers.test_elu",
|
|
||||||
"TestLayers.test_group_norm",
|
|
||||||
"TestLayers.test_hard_shrink",
|
|
||||||
"TestLayers.test_pooling",
|
|
||||||
"TestLayers.test_quantized_embedding",
|
|
||||||
"TestLayers.test_sin_pe",
|
|
||||||
"TestLayers.test_softshrink",
|
|
||||||
"TestLayers.test_upsample",
|
|
||||||
"TestOps.test_argpartition",
|
|
||||||
"TestOps.test_array_equal",
|
|
||||||
"TestOps.test_as_strided",
|
|
||||||
"TestOps.test_binary_ops",
|
|
||||||
"TestOps.test_bitwise_grad",
|
|
||||||
"TestOps.test_complex_ops",
|
|
||||||
"TestOps.test_divmod",
|
|
||||||
"TestOps.test_dynamic_slicing",
|
|
||||||
"TestOps.test_hadamard",
|
|
||||||
"TestOps.test_hadamard_grad_vmap",
|
|
||||||
"TestOps.test_irregular_binary_ops",
|
|
||||||
"TestOps.test_kron",
|
|
||||||
"TestOps.test_log1p",
|
|
||||||
"TestOps.test_logaddexp",
|
|
||||||
"TestOps.test_logcumsumexp",
|
|
||||||
"TestOps.test_partition",
|
|
||||||
"TestOps.test_scans",
|
|
||||||
"TestOps.test_softmax",
|
|
||||||
"TestOps.test_sort",
|
|
||||||
"TestOps.test_tensordot",
|
|
||||||
"TestOps.test_tile",
|
|
||||||
"TestQuantized.test_gather_matmul_grad",
|
"TestQuantized.test_gather_matmul_grad",
|
||||||
"TestQuantized.test_gather_qmm",
|
"TestQuantized.test_gather_qmm",
|
||||||
"TestQuantized.test_gather_qmm_sorted",
|
"TestQuantized.test_gather_qmm_sorted",
|
||||||
@ -120,12 +137,4 @@ cuda_skip = {
|
|||||||
"TestQuantized.test_small_matrix",
|
"TestQuantized.test_small_matrix",
|
||||||
"TestQuantized.test_throw",
|
"TestQuantized.test_throw",
|
||||||
"TestQuantized.test_vjp_scales_biases",
|
"TestQuantized.test_vjp_scales_biases",
|
||||||
"TestReduce.test_axis_permutation_sums",
|
|
||||||
"TestReduce.test_dtypes",
|
|
||||||
"TestReduce.test_expand_sums",
|
|
||||||
"TestReduce.test_many_reduction_axes",
|
|
||||||
"TestUpsample.test_torch_upsample",
|
|
||||||
"TestVmap.test_vmap_conv",
|
|
||||||
"TestVmap.test_vmap_inverse",
|
|
||||||
"TestVmap.test_vmap_svd",
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user