mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Merge b13c7ef8f8
into 4fda5fbdf9
This commit is contained in:
commit
eaf63aa031
@ -264,7 +264,6 @@ BINARY_GPU(Add)
|
||||
BINARY_GPU(ArcTan2)
|
||||
BINARY_GPU(Divide)
|
||||
BINARY_GPU(Remainder)
|
||||
BINARY_GPU(Equal)
|
||||
BINARY_GPU(Greater)
|
||||
BINARY_GPU(GreaterEqual)
|
||||
BINARY_GPU(Less)
|
||||
@ -279,6 +278,17 @@ BINARY_GPU(NotEqual)
|
||||
BINARY_GPU(Power)
|
||||
BINARY_GPU(Subtract)
|
||||
|
||||
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Equal::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
auto op = get_primitive_string(this);
|
||||
if (equal_nan_) {
|
||||
binary_op_gpu<cu::NaNEqual>(inputs, out, op, s);
|
||||
} else {
|
||||
binary_op_gpu<cu::Equal>(inputs, out, op, s);
|
||||
}
|
||||
}
|
||||
|
||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
|
@ -15,14 +15,7 @@ namespace mlx::core {
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
|
||||
using InType = cuda_type_t<CTYPE_IN>; \
|
||||
using OutType = cuda_type_t<CTYPE_OUT>; \
|
||||
if constexpr (cu::CastOp<InType, OutType>::is_castable) { \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
throw std::runtime_error(fmt::format( \
|
||||
"Can not copy data from dtype {} to {}.", \
|
||||
dtype_to_string(out.dtype()), \
|
||||
dtype_to_string(in.dtype()))); \
|
||||
} \
|
||||
__VA_ARGS__; \
|
||||
}); \
|
||||
})
|
||||
|
||||
|
@ -91,7 +91,7 @@ void unary_op_gpu_inplace(
|
||||
} else {
|
||||
auto [shape, strides] = collapse_contiguous_dims(in);
|
||||
auto [in_begin, in_end] = cu::make_general_iterators<int64_t>(
|
||||
in_ptr, in.data_size(), shape, strides);
|
||||
in_ptr, in.size(), shape, strides);
|
||||
thrust::transform(policy, in_begin, in_end, out_ptr, Op());
|
||||
}
|
||||
} else {
|
||||
|
Loading…
Reference in New Issue
Block a user