This commit is contained in:
Awni Hannun 2025-06-12 16:28:02 -07:00
parent bd2ea38397
commit 50dcaa6a7c
2 changed files with 2 additions and 5 deletions

View File

@ -244,8 +244,7 @@ void LayerNorm::eval_gpu(
}
};
array o = set_output(inputs[0]);
const array& x = o.data_shared_ptr() ? o : out;
const array x = set_output(inputs[0]);
const array& w = inputs[1];
const array& b = inputs[2];

View File

@ -156,7 +156,6 @@ __global__ void rms_norm_vjp(
cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) {
// thread_g * thread_w * norm - thread_x * meangwx * norm3
float xi = xn[i];
float wi = wn[i];
float gi = gn[i];
@ -214,8 +213,7 @@ void RMSNorm::eval_gpu(
}
};
array o = set_output(inputs[0]);
const array& x = o.data_shared_ptr() ? o : out;
const array x = set_output(inputs[0]);
const array& w = inputs[1];
int32_t axis_size = x.shape().back();