Update the no-copy condition in normalization to account for an axis of size 1

This commit is contained in:
Jagrit Digani 2025-06-11 09:58:15 -07:00
parent 53fa981caf
commit ea451af9a0

View File

@@ -26,7 +26,7 @@ void RMSNorm::eval_gpu(
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) { if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2]; auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back()); no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);
} }
if (no_copy) { if (no_copy) {
if (x.is_donatable()) { if (x.is_donatable()) {
@@ -227,7 +227,7 @@ void LayerNorm::eval_gpu(
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) { if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2]; auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back()); no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);
} }
if (no_copy) { if (no_copy) {
if (x.is_donatable()) { if (x.is_donatable()) {