mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Add softmin, hardshrink, hardtanh (#1180)
--------- Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
This commit is contained in:
		| @@ -9,7 +9,6 @@ from time_utils import time_fn | ||||
|  | ||||
|  | ||||
| def bench_gelu(): | ||||
|  | ||||
|     def gelu(x): | ||||
|         return x * (1 + mx.erf(x / math.sqrt(2))) / 2 | ||||
|  | ||||
| @@ -51,7 +50,6 @@ def bench_gelu(): | ||||
|  | ||||
|  | ||||
| def bench_layernorm(): | ||||
|  | ||||
|     weight = mx.random.uniform(shape=(4096,)).astype(mx.float16) | ||||
|     bias = mx.random.uniform(shape=(4096,)).astype(mx.float16) | ||||
|     mx.eval(weight, bias) | ||||
|   | ||||
| @@ -54,7 +54,6 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1): | ||||
|  | ||||
|  | ||||
| def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype): | ||||
|  | ||||
|     scale = 1.0 / math.sqrt(kH * kH * C) | ||||
|     a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype) | ||||
|     b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype( | ||||
|   | ||||
| @@ -35,7 +35,6 @@ def run_bench(system_size): | ||||
|  | ||||
|  | ||||
| def time_fft(): | ||||
|  | ||||
|     with mx.stream(mx.cpu): | ||||
|         cpu_bandwidths = run_bench(system_size=int(2**22)) | ||||
|  | ||||
|   | ||||
| @@ -10,7 +10,6 @@ SEQ_INCREMENT = 50 | ||||
|  | ||||
|  | ||||
| def time_self_attention_primitives(): | ||||
|  | ||||
|     mx.random.seed(3) | ||||
|     B = 2 | ||||
|     H = 38 | ||||
| @@ -32,7 +31,6 @@ def time_self_attention_primitives(): | ||||
|  | ||||
|  | ||||
| def time_self_attention_sdpa(): | ||||
|  | ||||
|     mx.random.seed(3) | ||||
|     B = 2 | ||||
|     H = 38 | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Nikhil Mehta
					Nikhil Mehta