mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Fix back-end bugs and enable corresponding tests (#2296)
* Fix some cuda back-end bugs and enable corresponding tests * more fixes * enable more tests * format
This commit is contained in:
@@ -165,7 +165,7 @@ void binary_op_gpu_inplace(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.data_size(),
|
||||
out.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(a_strides),
|
||||
const_param<NDIM>(b_strides));
|
||||
@@ -178,7 +178,7 @@ void binary_op_gpu_inplace(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.data_size(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
@@ -196,8 +196,8 @@ void binary_op_gpu_inplace(
|
||||
} else if (bopt == BinaryOpType::VectorVector) {
|
||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, LARGE);
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
@@ -264,7 +264,6 @@ BINARY_GPU(Add)
|
||||
BINARY_GPU(ArcTan2)
|
||||
BINARY_GPU(Divide)
|
||||
BINARY_GPU(Remainder)
|
||||
BINARY_GPU(Equal)
|
||||
BINARY_GPU(Greater)
|
||||
BINARY_GPU(GreaterEqual)
|
||||
BINARY_GPU(Less)
|
||||
@@ -279,6 +278,17 @@ BINARY_GPU(NotEqual)
|
||||
BINARY_GPU(Power)
|
||||
BINARY_GPU(Subtract)
|
||||
|
||||
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Equal::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
auto op = get_primitive_string(this);
|
||||
if (equal_nan_) {
|
||||
binary_op_gpu<cu::NaNEqual>(inputs, out, op, s);
|
||||
} else {
|
||||
binary_op_gpu<cu::Equal>(inputs, out, op, s);
|
||||
}
|
||||
}
|
||||
|
||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
|
||||
Reference in New Issue
Block a user