mirror of https://github.com/ml-explore/mlx.git

update bench
@@ -25,13 +25,13 @@ def sdpa(q, k, v):
     return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)


-def quant_sdpa(q, k, v):
+def quant_sdpa(q, k, v, bits=4):
     return mx.fast.quantized_scaled_dot_product_attention(
-        q, *k, *v, scale=1.0, mask=None, bits=8
+        q, *k, *v, scale=1.0, mask=None, bits=bits
     )


-def quant_attention(q, k, v):
+def quant_attention(q, k, v, bits=4):
     B, Hq, L, D = q.shape
     Hk = k[0].shape[1]

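A note on the calls above (not part of the commit): mx.quantize returns a tuple of packed weights, scales, and biases, which is why the benchmark splats the quantized keys and values into the kernel with *k and *v. The old code also hard-coded bits=8 in the kernel call while mx.quantize defaults to 4-bit, so making bits a parameter keeps the quantization and the kernel call in sync. A minimal sketch of that tuple, assuming the mlx.core Python API used in this file (the shape is illustrative only):

    import mlx.core as mx

    k = mx.random.uniform(shape=(1, 8, 1024, 64))   # illustrative K tensor
    # mx.quantize packs groups along the last axis and returns
    # (packed weights, scales, biases)
    w_q, scales, biases = mx.quantize(k, bits=4)
    # quant_sdpa receives this whole tuple as k and unpacks it with *k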
@@ -39,10 +39,10 @@ def quant_attention(q, k, v):
     k = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
     v = tree_map(lambda x: mx.expand_dims(x, axis=2), v)

-    scores = mx.quantized_matmul(q, *k, transpose=True)
+    scores = mx.quantized_matmul(q, *k, transpose=True, bits=bits)
     scores = mx.softmax(scores, axis=-1)

-    out = mx.quantized_matmul(scores, *v, transpose=False)
+    out = mx.quantized_matmul(scores, *v, transpose=False, bits=bits)
     out = out.reshape((B, Hq, L, D))
     return out

@@ -55,11 +55,11 @@ def time_self_attention_sdpa(q, k, v):
     time_fn(sdpa, q, k, v)


-def time_self_attention_quant_sdpa(q, k, v):
-    time_fn(quant_sdpa, q, k, v)
+def time_self_attention_quant_sdpa(q, k, v, bits=4):
+    time_fn(quant_sdpa, q, k, v, bits)


-def time_self_attention_quant_primitives(q, k, v):
+def time_self_attention_quant_primitives(q, k, v, bits=4):
     time_fn(quant_attention, q, k, v)

@@ -70,11 +70,12 @@ if __name__ == "__main__":
     v = mx.random.uniform(shape=(1, H_k, L, D))
     mx.eval(q, k, v)

-    k_quant = mx.quantize(k)
-    v_quant = mx.quantize(v)
+    bits = 4
+    k_quant = mx.quantize(k, bits=bits)
+    v_quant = mx.quantize(v, bits=bits)
     mx.eval(k_quant, v_quant)

     time_self_attention_sdpa(q, k, v)
-    time_self_attention_quant_sdpa(q, k_quant, v_quant)
+    time_self_attention_quant_sdpa(q, k_quant, v_quant, bits)
     time_self_attention_primitives(q, k, v)
-    time_self_attention_quant_primitives(q, k_quant, v_quant)
+    time_self_attention_quant_primitives(q, k_quant, v_quant, bits)
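Taken together, the commit threads a single bits setting from the driver through quantization, the fused quantized SDPA call, and the primitive-based path (note that time_self_attention_quant_primitives gains a bits argument but, as shown above, does not yet forward it to quant_attention). A hedged end-to-end sketch of the primitive-based path with that knob, using only calls that appear in this benchmark; the shapes and the absence of grouped-query attention are illustrative assumptions, not part of the diff:

    import mlx.core as mx

    bits = 4  # one knob for both quantization and the attention math

    q = mx.random.uniform(shape=(1, 8, 1, 64))
    k = mx.random.uniform(shape=(1, 8, 1024, 64))
    v = mx.random.uniform(shape=(1, 8, 1024, 64))

    # Quantize K/V once at the chosen bit width; each result is a
    # (packed weights, scales, biases) tuple.
    k_q = mx.quantize(k, bits=bits)
    v_q = mx.quantize(v, bits=bits)

    # Mirror quant_attention: pass the same bits to every quantized_matmul.
    scores = mx.quantized_matmul(q, *k_q, transpose=True, bits=bits)
    scores = mx.softmax(scores, axis=-1)
    out = mx.quantized_matmul(scores, *v_q, transpose=False, bits=bits)
    mx.eval(out)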
 