CUDA backend: unary ops (#2158)

This commit is contained in:
Cheng
2025-06-09 22:45:08 +09:00
committed by GitHub
parent 5866b3857b
commit f8bad60609
13 changed files with 1074 additions and 70 deletions

View File

@@ -1,5 +1,6 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/common/unary.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
@@ -99,21 +100,7 @@ void unary_op_gpu(
array& out,
const std::string op,
const Stream& s) {
auto& in = inputs[0];
bool contig = in.flags().contiguous;
if (contig) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
set_unary_output_data(inputs[0], out);
unary_op_gpu_inplace(inputs, out, op, s);
}