mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Update no copy condition in normalization to account for axis size 1
This commit is contained in:
parent
53fa981caf
commit
ea451af9a0
@ -26,7 +26,7 @@ void RMSNorm::eval_gpu(
|
||||
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||
if (no_copy && x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);
|
||||
}
|
||||
if (no_copy) {
|
||||
if (x.is_donatable()) {
|
||||
@ -227,7 +227,7 @@ void LayerNorm::eval_gpu(
|
||||
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||
if (no_copy && x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);
|
||||
}
|
||||
if (no_copy) {
|
||||
if (x.is_donatable()) {
|
||||
|
Loading…
Reference in New Issue
Block a user