[CUDA] RMSNorm and VJP (#2280)

* rms norm start

* nit
This commit is contained in:
Awni Hannun
2025-06-12 17:09:49 -07:00
committed by GitHub
parent a4fc671d3e
commit 918761a25a
4 changed files with 345 additions and 4 deletions

View File

@@ -244,8 +244,7 @@ void LayerNorm::eval_gpu(
}
};
array o = set_output(inputs[0]);
const array& x = o.data_shared_ptr() ? o : out;
const array x = set_output(inputs[0]);
const array& w = inputs[1];
const array& b = inputs[2];