diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu
index 47efc44d2..45ade0fda 100644
--- a/mlx/backend/cuda/binary.cu
+++ b/mlx/backend/cuda/binary.cu
@@ -264,7 +264,6 @@ BINARY_GPU(Add)
 BINARY_GPU(ArcTan2)
 BINARY_GPU(Divide)
 BINARY_GPU(Remainder)
-BINARY_GPU(Equal)
 BINARY_GPU(Greater)
 BINARY_GPU(GreaterEqual)
 BINARY_GPU(Less)
@@ -279,6 +278,17 @@ BINARY_GPU(NotEqual)
 BINARY_GPU(Power)
 BINARY_GPU(Subtract)
 
+void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("Equal::eval_gpu");
+  auto& s = out.primitive().stream();
+  auto op = get_primitive_string(this);
+  if (equal_nan_) {
+    binary_op_gpu<cu::NaNEqual>(inputs, out, op, s);
+  } else {
+    binary_op_gpu<cu::Equal>(inputs, out, op, s);
+  }
+}
+
 void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
   nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
   auto& s = out.primitive().stream();
diff --git a/mlx/backend/cuda/copy/copy.cuh b/mlx/backend/cuda/copy/copy.cuh
index 0c1eff774..ee5120274 100644
--- a/mlx/backend/cuda/copy/copy.cuh
+++ b/mlx/backend/cuda/copy/copy.cuh
@@ -15,14 +15,7 @@ namespace mlx::core {
     MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {              \
       using InType = cuda_type_t<CTYPE_IN>;                     \
       using OutType = cuda_type_t<CTYPE_OUT>;                   \
-      if constexpr (cu::CastOp<InType, OutType>::is_castable) { \
-        __VA_ARGS__;                                            \
-      } else {                                                  \
-        throw std::runtime_error(fmt::format(                   \
-            "Can not copy data from dtype {} to {}.",           \
-            dtype_to_string(out.dtype()),                       \
-            dtype_to_string(in.dtype())));                      \
-      }                                                         \
+      __VA_ARGS__;                                              \
     });                                                         \
   })
 
diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu
index f9d373455..f02898705 100644
--- a/mlx/backend/cuda/unary.cu
+++ b/mlx/backend/cuda/unary.cu
@@ -91,7 +91,7 @@ void unary_op_gpu_inplace(
   } else {
     auto [shape, strides] = collapse_contiguous_dims(in);
     auto [in_begin, in_end] = cu::make_general_iterators<int64_t>(
-        in_ptr, in.data_size(), shape, strides);
+        in_ptr, in.size(), shape, strides);
     thrust::transform(policy, in_begin, in_end, out_ptr, Op());
   }
 } else {