use fp32 for testing, add more complex ops (#2322)

This commit is contained in:
Awni Hannun
2025-07-01 07:30:00 -07:00
committed by GitHub
parent 3d5e17e507
commit dd4f53db63
6 changed files with 68 additions and 40 deletions

View File

@@ -1,25 +1,15 @@
cuda_skip = {
"TestArray.test_api",
"TestBF16.test_arg_reduction_ops",
"TestBlas.test_complex_gemm",
"TestEinsum.test_ellipses",
"TestEinsum.test_opt_einsum_test_cases",
"TestLoad.test_load_f8_e4m3",
"TestLayers.test_group_norm",
"TestLayers.test_pooling",
"TestLayers.test_quantized_embedding",
"TestLayers.test_sin_pe",
"TestLayers.test_upsample",
"TestOps.test_complex_ops",
"TestOps.test_dynamic_slicing",
"TestReduce.test_dtypes",
"TestUpsample.test_torch_upsample",
# Block masked matmul NYI
"TestBlas.test_block_masked_matmul",
# Gather matmul NYI
"TestBlas.test_gather_matmul",
"TestBlas.test_gather_matmul_grad",
# Scan NYI
"TestArray.test_api",
"TestAutograd.test_cumprod_grad",
"TestOps.test_scans",
"TestOps.test_logcumsumexp",

View File

@@ -1,6 +1,10 @@
# Copyright © 2023 Apple Inc.
import os
# Use regular fp32 precision for tests
os.environ["MLX_ENABLE_TF32"] = "0"
import platform
import unittest
from typing import Any, Callable, List, Tuple, Union