mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
parent
08a1bf3f10
commit
00794c42bc
@ -172,6 +172,7 @@ void sdpa_vector(
|
|||||||
std::string hash_name = kname;
|
std::string hash_name = kname;
|
||||||
hash_name += has_mask ? "_mask" : "_nomask";
|
hash_name += has_mask ? "_mask" : "_nomask";
|
||||||
hash_name += query_transposed ? "_qt" : "_qnt";
|
hash_name += query_transposed ? "_qt" : "_qnt";
|
||||||
|
hash_name += do_causal ? "_c" : "_nc";
|
||||||
|
|
||||||
// Get the kernel
|
// Get the kernel
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
@ -268,6 +269,7 @@ void sdpa_vector_2pass(
|
|||||||
std::string hash_name = kname;
|
std::string hash_name = kname;
|
||||||
hash_name += has_mask ? "_mask" : "_nomask";
|
hash_name += has_mask ? "_mask" : "_nomask";
|
||||||
hash_name += query_transposed ? "_qt" : "_qnt";
|
hash_name += query_transposed ? "_qt" : "_qnt";
|
||||||
|
hash_name += do_causal ? "_c" : "_nc";
|
||||||
|
|
||||||
// Get the kernel
|
// Get the kernel
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
@ -347,6 +347,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
masks = [
|
masks = [
|
||||||
|
None,
|
||||||
mx.array(True),
|
mx.array(True),
|
||||||
mx.array([True] * (L - 10) + [False] * 10),
|
mx.array([True] * (L - 10) + [False] * 10),
|
||||||
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
||||||
@ -392,7 +393,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
|||||||
def test_fast_sdpa_few_query(self):
|
def test_fast_sdpa_few_query(self):
|
||||||
D = 64
|
D = 64
|
||||||
L = 43
|
L = 43
|
||||||
Lq = 4
|
Lq = 8
|
||||||
Nq = 8
|
Nq = 8
|
||||||
Nkv = 1
|
Nkv = 1
|
||||||
scale = 1.0
|
scale = 1.0
|
||||||
@ -403,6 +404,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
|||||||
v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
|
v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
|
||||||
|
|
||||||
masks = [
|
masks = [
|
||||||
|
None,
|
||||||
mx.array(True),
|
mx.array(True),
|
||||||
mx.array([True] * (L - 10) + [False] * 10),
|
mx.array([True] * (L - 10) + [False] * 10),
|
||||||
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
||||||
@ -428,6 +430,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
|||||||
v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
|
v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
|
||||||
|
|
||||||
masks = [
|
masks = [
|
||||||
|
None,
|
||||||
mx.array(True),
|
mx.array(True),
|
||||||
mx.array([True] * (L - 10) + [False] * 10),
|
mx.array([True] * (L - 10) + [False] * 10),
|
||||||
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
||||||
|
Loading…
Reference in New Issue
Block a user