mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Fix back-end bugs and enable corresponding tests (#2296)
* Fix some cuda back-end bugs and enable corresponding tests * more fixes * enable more tests * format
This commit is contained in:
@@ -10,20 +10,13 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
|
||||
using InType = cuda_type_t<CTYPE_IN>; \
|
||||
using OutType = cuda_type_t<CTYPE_OUT>; \
|
||||
if constexpr (cu::CastOp<InType, OutType>::is_castable) { \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
throw std::runtime_error(fmt::format( \
|
||||
"Can not copy data from dtype {} to {}.", \
|
||||
dtype_to_string(out.dtype()), \
|
||||
dtype_to_string(in.dtype()))); \
|
||||
} \
|
||||
}); \
|
||||
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
|
||||
using InType = cuda_type_t<CTYPE_IN>; \
|
||||
using OutType = cuda_type_t<CTYPE_OUT>; \
|
||||
__VA_ARGS__; \
|
||||
}); \
|
||||
})
|
||||
|
||||
void copy_contiguous(
|
||||
|
||||
Reference in New Issue
Block a user