mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-21 18:28:11 +08:00
CUDA backend: unary ops (#2158)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
@@ -99,21 +100,7 @@ void unary_op_gpu(
|
||||
array& out,
|
||||
const std::string op,
|
||||
const Stream& s) {
|
||||
auto& in = inputs[0];
|
||||
bool contig = in.flags().contiguous;
|
||||
if (contig) {
|
||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
set_unary_output_data(inputs[0], out);
|
||||
unary_op_gpu_inplace(inputs, out, op, s);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user