mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-30 13:41:14 +08:00
nit
This commit is contained in:
parent
bd2ea38397
commit
50dcaa6a7c
@@ -244,8 +244,7 @@ void LayerNorm::eval_gpu(
 }
 };
 
-array o = set_output(inputs[0]);
-const array& x = o.data_shared_ptr() ? o : out;
+const array x = set_output(inputs[0]);
 const array& w = inputs[1];
 const array& b = inputs[2];
 
@@ -156,7 +156,6 @@ __global__ void rms_norm_vjp(
 cub::LoadDirectBlocked(index, g, gn, axis_size);
 cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
 for (int i = 0; i < N_READS; i++) {
-// thread_g * thread_w * norm - thread_x * meangwx * norm3
 float xi = xn[i];
 float wi = wn[i];
 float gi = gn[i];
@@ -214,8 +213,7 @@ void RMSNorm::eval_gpu(
 }
 };
 
-array o = set_output(inputs[0]);
-const array& x = o.data_shared_ptr() ? o : out;
+const array x = set_output(inputs[0]);
 const array& w = inputs[1];
 
 int32_t axis_size = x.shape().back();
Loading…
Reference in New Issue
Block a user