mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-02 16:56:46 +08:00
Fix data size bug
This commit is contained in:
parent
4640f865cc
commit
ed4fb26cb9
@ -330,6 +330,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
size_t str_oH = o.shape(3);
|
||||
size_t str_oL = o.shape(1) * str_oH;
|
||||
size_t str_oB = o.shape(2) * str_oL;
|
||||
size_t data_size = o.shape(0) * str_oB;
|
||||
|
||||
array::Flags flags{
|
||||
/* bool contiguous = */ 1,
|
||||
@ -339,7 +340,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
|
||||
o.set_data(
|
||||
allocator::malloc_or_wait(o.nbytes()),
|
||||
o.data_size(),
|
||||
data_size,
|
||||
{str_oB, str_oH, str_oL, str_oD},
|
||||
flags);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user