mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-01 16:58:08 +08:00
Add softmin, hardshrink, hardtanh (#1180)
--------- Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
This commit is contained in:
@@ -15,7 +15,6 @@ def mlx_primitives_sdpa(q, k, v, scale):
|
||||
|
||||
# SDPA for GQA (n_heads > n_kv_heads, n_kv_heads > 1, n_heads % n_kv_heads == 0)
|
||||
def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
|
||||
|
||||
n_repeats = q.shape[1] // k.shape[1]
|
||||
|
||||
# borrowing kv cache tiling from mlx-examples/llms/mistral/mistral.py
|
||||
@@ -34,7 +33,6 @@ def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
|
||||
|
||||
class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
|
||||
def test_fast_sdpa(self):
|
||||
|
||||
# Not yet supported:
|
||||
# * K pre-transposed in kernel, V pre-transposed in kernel
|
||||
np.random.seed(0)
|
||||
|
||||
Reference in New Issue
Block a user