mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Update the no-copy condition in normalization to account for an axis of size 1
This commit is contained in:
parent
53fa981caf
commit
ea451af9a0
@@ -26,7 +26,7 @@ void RMSNorm::eval_gpu(
|
|||||||
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||||
if (no_copy && x.ndim() > 1) {
|
if (no_copy && x.ndim() > 1) {
|
||||||
auto s = x.strides()[x.ndim() - 2];
|
auto s = x.strides()[x.ndim() - 2];
|
||||||
no_copy &= (s == 0 || s == x.shape().back());
|
no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);
|
||||||
}
|
}
|
||||||
if (no_copy) {
|
if (no_copy) {
|
||||||
if (x.is_donatable()) {
|
if (x.is_donatable()) {
|
||||||
@@ -227,7 +227,7 @@ void LayerNorm::eval_gpu(
|
|||||||
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||||
if (no_copy && x.ndim() > 1) {
|
if (no_copy && x.ndim() > 1) {
|
||||||
auto s = x.strides()[x.ndim() - 2];
|
auto s = x.strides()[x.ndim() - 2];
|
||||||
no_copy &= (s == 0 || s == x.shape().back());
|
no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);
|
||||||
}
|
}
|
||||||
if (no_copy) {
|
if (no_copy) {
|
||||||
if (x.is_donatable()) {
|
if (x.is_donatable()) {
|
||||||
|
Loading…
Reference in New Issue
Block a user