use fp32 for testing, add more complex ops (#2322)

This commit is contained in:
Awni Hannun
2025-07-01 07:30:00 -07:00
committed by GitHub
parent 3d5e17e507
commit dd4f53db63
6 changed files with 68 additions and 40 deletions

View File

@@ -342,8 +342,6 @@ void LayerNormVJP::eval_gpu(
encoder.add_temporary(gw_temp);
}
}
-gw.set_data(allocator::malloc(gw.nbytes()));
-gb.set_data(allocator::malloc(gb.nbytes()));
// Finish with the gradient for b in case we had a b.
if (gb.ndim() == 1 && gb.size() == axis_size) {