mlx/python
Latest commit 9adcd1a650 by Jagrit Digani
Support fused masking in Attention (#1924)
* Update API to allow mask='causal' in fast::sdpa (see the usage sketch below this commit message)

* Add fallback

* Update steel::AttnParams

* Fix typo

* WIP, basic causal

* Update tests

* Update benchmarking

* Update masking loop limits

* Add bool masking and update tests

* Update additive mask

* Update benchmarks

* Update benchmarks

* Update tests

* Update for bfloat error

* Update early exit

* Add random seed to tests
Committed 2025-03-20 11:01:32 -07:00
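
A minimal usage sketch of the mask options this commit describes, assuming the mlx.core.fast.scaled_dot_product_attention signature (q, k, v, *, scale, mask) and a [batch, heads, sequence, head_dim] layout; the boolean-mask convention (True = attend) is an assumption, not something stated in this log.

    import mlx.core as mx

    B, H, L, D = 1, 8, 128, 64
    q = mx.random.normal((B, H, L, D))
    k = mx.random.normal((B, H, L, D))
    v = mx.random.normal((B, H, L, D))
    scale = D ** -0.5

    # String mask: causal masking fused into the attention kernel.
    out_causal = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal")

    # Boolean mask: True positions are attended to (assumed convention).
    bool_mask = mx.tril(mx.ones((L, L), dtype=mx.bool_))
    out_bool = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=bool_mask)

    # Additive mask: added to the scaled scores before the softmax.
    add_mask = mx.where(bool_mask, 0.0, float("-inf")).astype(q.dtype)
    out_add = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=add_mask)

Under these assumptions all three calls express the same causal attention, differing only in how the mask is passed.
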
Directory   Last commit                                   Updated
mlx/        Update smooth_l1_loss in losses.py (#1974)    2025-03-19 20:19:02 -07:00
src/        Support fused masking in Attention (#1924)    2025-03-20 11:01:32 -07:00
tests/      Support fused masking in Attention (#1924)    2025-03-20 11:01:32 -07:00