Fix data size bug

This commit is contained in:
Jagrit Digani 2024-11-21 12:54:12 -08:00
parent 4640f865cc
commit ed4fb26cb9

View File

@ -330,6 +330,7 @@ void ScaledDotProductAttention::eval_gpu(
size_t str_oH = o.shape(3);
size_t str_oL = o.shape(1) * str_oH;
size_t str_oB = o.shape(2) * str_oL;
size_t data_size = o.shape(0) * str_oB;
array::Flags flags{
/* bool contiguous = */ 1,
@ -339,7 +340,7 @@ void ScaledDotProductAttention::eval_gpu(
o.set_data(
allocator::malloc_or_wait(o.nbytes()),
o.data_size(),
data_size,
{str_oB, str_oH, str_oL, str_oD},
flags);