mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-29 04:51:13 +08:00
nit
This commit is contained in:
parent
bd2ea38397
commit
50dcaa6a7c
@ -244,8 +244,7 @@ void LayerNorm::eval_gpu(
|
||||
}
|
||||
};
|
||||
|
||||
array o = set_output(inputs[0]);
|
||||
const array& x = o.data_shared_ptr() ? o : out;
|
||||
const array x = set_output(inputs[0]);
|
||||
const array& w = inputs[1];
|
||||
const array& b = inputs[2];
|
||||
|
||||
|
@ -156,7 +156,6 @@ __global__ void rms_norm_vjp(
|
||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
// thread_g * thread_w * norm - thread_x * meangwx * norm3
|
||||
float xi = xn[i];
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
@ -214,8 +213,7 @@ void RMSNorm::eval_gpu(
|
||||
}
|
||||
};
|
||||
|
||||
array o = set_output(inputs[0]);
|
||||
const array& x = o.data_shared_ptr() ? o : out;
|
||||
const array x = set_output(inputs[0]);
|
||||
const array& w = inputs[1];
|
||||
|
||||
int32_t axis_size = x.shape().back();
|
||||
|
Loading…
Reference in New Issue
Block a user