Layer norm grad fix donation bug (#941)

* add layer norm grad test

* Fix donation bug in layernorm vjp

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Angelos Katharopoulos
2024-04-01 06:15:50 -07:00
committed by GitHub
parent 9cbff5ec1d
commit 110d9b149d
2 changed files with 23 additions and 1 deletions

@@ -355,7 +355,14 @@ void LayerNormVJP::eval_gpu(
     ReductionPlan plan(
         ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
     strided_reduce_general_dispatch(
-        g, gb, "sum", plan, {0}, compute_encoder, d, s);
+        g_in_gx ? gx : (g_in_gw ? gw_temp : g),
+        gb,
+        "sum",
+        plan,
+        {0},
+        compute_encoder,
+        d,
+        s);
   }
   const int simd_size = 32;