contiguous op / prim (#1612)

This commit is contained in:
Awni Hannun
2024-11-21 19:51:49 -08:00
committed by GitHub
parent 0d5e7716ad
commit dcca0d7477
11 changed files with 104 additions and 25 deletions

View File

@@ -170,6 +170,17 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
concatenate_gpu(inputs, out, axis_, stream());
}
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
move_or_copy(in, out);
} else {
copy_gpu(in, out, CopyType::General);
}
}
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}