more bug fixes

This commit is contained in:
Awni Hannun 2025-06-16 09:35:58 -07:00
parent c552ff2451
commit 7429613f76
7 changed files with 91 additions and 62 deletions

View File

@ -150,10 +150,10 @@ void binary_op_gpu_inplace(
auto [shape, strides] = collapse_contiguous_dims(a, b, out); auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& a_strides = strides[0]; auto& a_strides = strides[0];
auto& b_strides = strides[1]; auto& b_strides = strides[1];
bool large = a.data_size() > UINT32_MAX || bool large = a.data_size() > INT32_MAX ||
b.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX; b.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, { MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>; using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size(); int ndim = shape.size();
if (ndim <= 3) { if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, { MLX_SWITCH_1_2_3(ndim, NDIM, {

View File

@ -130,11 +130,13 @@ struct FusedKernelBuilder {
constexpr const char* g_jit_includes = R"( constexpr const char* g_jit_includes = R"(
#include "mlx/backend/cuda/device/binary_ops.cuh" #include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/device/ternary_ops.cuh"
#include "mlx/backend/cuda/device/unary_ops.cuh" #include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/device/utils.cuh"
#include <cooperative_groups.h> #include <cooperative_groups.h>
#define inf cuda::std::numeric_limits<float>::infinity()
)"; )";
void Compiled::eval_gpu( void Compiled::eval_gpu(

View File

@ -1,4 +1,5 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#pragma once
namespace mlx::core::cu { namespace mlx::core::cu {

View File

@ -336,4 +336,21 @@ struct LoopedElemToLoc<1, false, OffsetT> {
} }
}; };
// Complex log1p: computes log(1 + z) for z = x + iy, accurate when |z| is
// small (where naive log(1 + z) loses precision to cancellation).
// Real part: 0.5 * log1p(|1+z|^2 - 1) for small |z|, log(|1+z|) otherwise.
// Imaginary part: arg(1 + z) = atan2(y, x + 1) in both regimes.
inline __device__ cuComplex log1p(cuComplex in) {
  float x = cuCrealf(in);
  float y = cuCimagf(in);
  // hypotf is robust to intermediate overflow/underflow, unlike
  // sqrt(x*x + y*y); also avoids the double-precision sqrt() call.
  float zabs = hypotf(x, y);
  float theta = atan2f(y, x + 1.0f);
  if (zabs < 0.5f) {
    // r = |1+z|^2 - 1 = x*(2+x) + y^2; feed it to log1pf to keep
    // precision when r is tiny.
    float r = x * (2.0f + x) + y * y;
    if (r == 0.0f) { // handle underflow: log1p(r)/2 ~ r/2 ~ x here
      return {x, theta};
    }
    return {0.5f * log1pf(r), theta};
  } else {
    // |z| >= 0.5: direct evaluation is stable; use float-precision
    // hypotf/logf rather than the double-promoting sqrt()/log().
    float z0 = hypotf(x + 1.0f, y);
    return {logf(z0), theta};
  }
}
} // namespace mlx::core::cu } // namespace mlx::core::cu

View File

@ -101,10 +101,10 @@ void ternary_op_gpu_inplace(
auto& a_strides = strides[0]; auto& a_strides = strides[0];
auto& b_strides = strides[1]; auto& b_strides = strides[1];
auto& c_strides = strides[2]; auto& c_strides = strides[2];
bool large = a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX || bool large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
c.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX; c.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, { MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>; using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size(); int ndim = shape.size();
if (ndim <= 3) { if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, { MLX_SWITCH_1_2_3(ndim, NDIM, {

View File

@ -27,13 +27,13 @@ constexpr bool supports_unary_op() {
std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> || std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> || std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> || std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> || std::is_same_v<Op, Expm1> ||
std::is_same_v<Op, Sigmoid> || std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Sigmoid> || std::is_same_v<Op, Sqrt> ||
std::is_same_v<Op, Rsqrt>) { std::is_same_v<Op, Rsqrt>) {
return std::is_same_v<In, Out> && is_floating_v<In>; return std::is_same_v<In, Out> && is_floating_v<In>;
} }
if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> || if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
std::is_same_v<Op, Log10>) { std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p>) {
return std::is_same_v<In, Out> && is_inexact_v<In>; return std::is_same_v<In, Out> && is_inexact_v<In>;
} }
if (std::is_same_v<Op, BitwiseInvert>) { if (std::is_same_v<Op, BitwiseInvert>) {

View File

@ -1,24 +1,68 @@
cuda_skip = { cuda_skip = {
"TestArray.test_api", "TestArray.test_api",
"TestAutograd.test_cumprod_grad",
"TestAutograd.test_slice_grads", "TestAutograd.test_slice_grads",
"TestAutograd.test_split_against_slice", "TestAutograd.test_split_against_slice",
"TestAutograd.test_stop_gradient", "TestAutograd.test_stop_gradient",
"TestAutograd.test_topk_grad",
"TestAutograd.test_update_state", "TestAutograd.test_update_state",
"TestAutograd.test_vjp", "TestAutograd.test_vjp",
"TestBF16.test_arg_reduction_ops", "TestBF16.test_arg_reduction_ops",
"TestBF16.test_binary_ops", "TestBF16.test_binary_ops",
"TestBF16.test_reduction_ops", "TestBF16.test_reduction_ops",
"TestBlas.test_block_masked_matmul",
"TestBlas.test_complex_gemm", "TestBlas.test_complex_gemm",
"TestBlas.test_gather_matmul",
"TestBlas.test_gather_matmul_grad",
"TestBlas.test_matmul_batched", "TestBlas.test_matmul_batched",
"TestBlas.test_matrix_vector_attn", "TestBlas.test_matrix_vector_attn",
"TestCompile.test_compile_dynamic_dims", "TestCompile.test_compile_dynamic_dims",
"TestCompile.test_compile_inf", "TestEinsum.test_attention",
"TestCompile.test_inf_constant", "TestEinsum.test_ellipses",
"TestEinsum.test_opt_einsum_test_cases",
"TestEval.test_multi_output_eval_during_transform",
"TestLoad.test_load_f8_e4m3",
"TestLosses.test_binary_cross_entropy",
"TestMemory.test_memory_info",
"TestLayers.test_group_norm",
"TestLayers.test_pooling",
"TestLayers.test_quantized_embedding",
"TestLayers.test_sin_pe",
"TestLayers.test_upsample",
"TestOps.test_array_equal",
"TestOps.test_complex_ops",
"TestOps.test_divmod",
"TestOps.test_dynamic_slicing",
"TestOps.test_irregular_binary_ops",
"TestOps.test_kron",
"TestOps.test_logaddexp",
"TestOps.test_softmax",
"TestOps.test_sort",
"TestOps.test_tensordot",
"TestOps.test_tile",
"TestReduce.test_axis_permutation_sums",
"TestReduce.test_dtypes",
"TestReduce.test_expand_sums",
"TestReduce.test_many_reduction_axes",
"TestUpsample.test_torch_upsample",
# Partition NYI
"TestAutograd.test_topk_grad",
"TestOps.test_argpartition",
"TestOps.test_partition",
# Block masked matmul NYI
"TestBlas.test_block_masked_matmul",
# Gather matmul NYI
"TestBlas.test_gather_matmul",
"TestBlas.test_gather_matmul_grad",
# Scan NYI
"TestAutograd.test_cumprod_grad",
"TestOps.test_scans",
"TestOps.test_logcumsumexp",
# Hadamard NYI
"TestOps.test_hadamard",
"TestOps.test_hadamard_grad_vmap",
# Convolutions NYI
"TestConv.test_1d_conv_with_2d", "TestConv.test_1d_conv_with_2d",
"TestConv.test_asymmetric_padding", "TestConv.test_asymmetric_padding",
"TestConv.test_basic_grad_shapes", "TestConv.test_basic_grad_shapes",
@ -45,11 +89,12 @@ cuda_skip = {
"TestConvTranspose.test_torch_conv_transpose_3D", "TestConvTranspose.test_torch_conv_transpose_3D",
"TestConvTranspose.test_torch_conv_transpose_3D_grad", "TestConvTranspose.test_torch_conv_transpose_3D_grad",
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding", "TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
"TestEinsum.test_attention",
"TestEinsum.test_ellipses",
"TestEinsum.test_opt_einsum_test_cases",
"TestEval.test_multi_output_eval_during_transform",
"TestExportImport.test_export_conv", "TestExportImport.test_export_conv",
"TestLayers.test_conv1d",
"TestLayers.test_conv2d",
"TestVmap.test_vmap_conv",
# FFTs NYI
"TestFFT.test_fft", "TestFFT.test_fft",
"TestFFT.test_fft_big_powers_of_two", "TestFFT.test_fft_big_powers_of_two",
"TestFFT.test_fft_contiguity", "TestFFT.test_fft_contiguity",
@ -59,52 +104,24 @@ cuda_skip = {
"TestFFT.test_fft_large_numbers", "TestFFT.test_fft_large_numbers",
"TestFFT.test_fft_shared_mem", "TestFFT.test_fft_shared_mem",
"TestFFT.test_fftn", "TestFFT.test_fftn",
"TestInit.test_orthogonal",
# Lapack ops NYI
"TestLinalg.test_cholesky", "TestLinalg.test_cholesky",
"TestLinalg.test_cholesky_inv", "TestLinalg.test_cholesky_inv",
"TestLinalg.test_eig", "TestLinalg.test_eig",
"TestLinalg.test_eigh", "TestLinalg.test_eigh",
"TestLinalg.test_inverse", "TestLinalg.test_inverse",
"TestVmap.test_vmap_inverse",
"TestLinalg.test_lu", "TestLinalg.test_lu",
"TestLinalg.test_lu_factor", "TestLinalg.test_lu_factor",
"TestLinalg.test_pseudo_inverse", "TestLinalg.test_pseudo_inverse",
"TestLinalg.test_qr_factorization", "TestLinalg.test_qr_factorization",
"TestInit.test_orthogonal",
"TestLinalg.test_svd_decomposition", "TestLinalg.test_svd_decomposition",
"TestVmap.test_vmap_svd",
"TestLinalg.test_tri_inverse", "TestLinalg.test_tri_inverse",
"TestLoad.test_load_f8_e4m3",
"TestLosses.test_binary_cross_entropy", # Quantization NYI
"TestMemory.test_memory_info",
"TestLayers.test_conv1d",
"TestLayers.test_conv2d",
"TestLayers.test_elu",
"TestLayers.test_group_norm",
"TestLayers.test_hard_shrink",
"TestLayers.test_pooling",
"TestLayers.test_quantized_embedding",
"TestLayers.test_sin_pe",
"TestLayers.test_softshrink",
"TestLayers.test_upsample",
"TestOps.test_argpartition",
"TestOps.test_array_equal",
"TestOps.test_as_strided",
"TestOps.test_binary_ops",
"TestOps.test_bitwise_grad",
"TestOps.test_complex_ops",
"TestOps.test_divmod",
"TestOps.test_dynamic_slicing",
"TestOps.test_hadamard",
"TestOps.test_hadamard_grad_vmap",
"TestOps.test_irregular_binary_ops",
"TestOps.test_kron",
"TestOps.test_log1p",
"TestOps.test_logaddexp",
"TestOps.test_logcumsumexp",
"TestOps.test_partition",
"TestOps.test_scans",
"TestOps.test_softmax",
"TestOps.test_sort",
"TestOps.test_tensordot",
"TestOps.test_tile",
"TestQuantized.test_gather_matmul_grad", "TestQuantized.test_gather_matmul_grad",
"TestQuantized.test_gather_qmm", "TestQuantized.test_gather_qmm",
"TestQuantized.test_gather_qmm_sorted", "TestQuantized.test_gather_qmm_sorted",
@ -120,12 +137,4 @@ cuda_skip = {
"TestQuantized.test_small_matrix", "TestQuantized.test_small_matrix",
"TestQuantized.test_throw", "TestQuantized.test_throw",
"TestQuantized.test_vjp_scales_biases", "TestQuantized.test_vjp_scales_biases",
"TestReduce.test_axis_permutation_sums",
"TestReduce.test_dtypes",
"TestReduce.test_expand_sums",
"TestReduce.test_many_reduction_axes",
"TestUpsample.test_torch_upsample",
"TestVmap.test_vmap_conv",
"TestVmap.test_vmap_inverse",
"TestVmap.test_vmap_svd",
} }