Add contiguous_copy_gpu util for copying array (#2379)

This commit is contained in:
Cheng
2025-07-18 22:44:25 +09:00
committed by GitHub
parent 31fc530c76
commit 45adec102c
20 changed files with 40 additions and 67 deletions

View File

@@ -206,8 +206,7 @@ void RMSNorm::eval_gpu(
}
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
array x_copy = contiguous_copy_gpu(x, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
@@ -259,9 +258,7 @@ void RMSNormVJP::eval_gpu(
return x;
}
copied = true;
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
return x_copy;
return contiguous_copy_gpu(x, s);
};
bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[2].is_donatable();