mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-31 16:21:27 +08:00)
@@ -682,7 +682,7 @@ class TestBlas(mlx_tests.MLXTestCase):
         self.assertEqual(c.shape, (0, 0))
 
     def test_block_masked_matmul(self):
-        def np_block_masked_mm(
+        def ref_block_masked_mm(
             a, b, block_size, out_mask=None, lhs_mask=None, rhs_mask=None
         ):
             # Get mask adjusted shapes
@@ -690,33 +690,81 @@ class TestBlas(mlx_tests.MLXTestCase):
             N = b.shape[-1]
             K = a.shape[-1]
 
+            bsx_shape = np.broadcast_shapes(a.shape[:-2], b.shape[:-2])
+
             # Expand mask dims
             def expand_mask(mask, block_size, Y, X):
-                mask = np.expand_dims(mask, (-3, -1))
-                mask_shape = list(mask.shape)
+                mask = mx.expand_dims(mask, (-3, -1))
+                mask_shape = list(bsx_shape) + list(mask.shape[-4:])
                 mask_shape[-1] = block_size
                 x = mask_shape[-2] * block_size
                 mask_shape[-3] = block_size
                 y = mask_shape[-4] * block_size
-                mask = np.broadcast_to(mask, mask_shape)
+                mask = mx.broadcast_to(mask, mask_shape)
                 mask_shape = mask_shape[:-4] + [y, x]
                 return mask.reshape(mask_shape)[..., :Y, :X]
 
+            a_masked = a
+            b_masked = b
+
             if lhs_mask is not None:
-                lhs_mask = expand_mask(lhs_mask, block_size, M, K)
-                a = lhs_mask * a
+                lhs_mask = expand_mask(lhs_mask, block_size, M, K).astype(mx.float32)
+                a_masked = lhs_mask * a_masked
 
             if rhs_mask is not None:
-                rhs_mask = expand_mask(rhs_mask, block_size, K, N)
-                b = rhs_mask * b
+                rhs_mask = expand_mask(rhs_mask, block_size, K, N).astype(mx.float32)
+                b_masked = rhs_mask * b_masked
 
-            out = a @ b
+            out = a_masked @ b_masked
 
             if out_mask is not None:
-                out_mask = expand_mask(out_mask, block_size, M, N)
+                out_mask = expand_mask(out_mask, block_size, M, N).astype(mx.float32)
                 out = out * out_mask
             return out
 
+        def run_test(a, b, block_size, out_mask, a_mask, b_mask, cotan):
+            def f_ref(a_, b_):
+                return ref_block_masked_mm(a_, b_, block_size, out_mask, a_mask, b_mask)
+
+            def f_test(a_, b_):
+                return mx.block_masked_mm(a_, b_, block_size, out_mask, a_mask, b_mask)
+
+            out_ref, dout_ref = mx.vjp(f_ref, [a, b], [cotan])
+            out_test, dout_test = mx.vjp(f_test, [a, b], [cotan])
+
+            mx.eval((out_ref, dout_ref, out_test, dout_test))
+
+            self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-5).item())
+
+        def run_test_mask_vjp(a, b, block_size, out_mask, a_mask, b_mask, cotan):
+            def f_ref(a_, b_, a_mask_, b_mask_):
+                return ref_block_masked_mm(
+                    a_, b_, block_size, out_mask, a_mask_, b_mask_
+                )
+
+            def f_test(a_, b_, a_mask_, b_mask_):
+                return mx.block_masked_mm(
+                    a_, b_, block_size, out_mask, a_mask_, b_mask_
+                )
+
+            out_ref, dout_ref = mx.vjp(f_ref, [a, b, a_mask, b_mask], [cotan])
+            out_test, dout_test = mx.vjp(f_test, [a, b, a_mask, b_mask], [cotan])
+
+            mx.eval((out_ref, dout_ref, out_test, dout_test))
+
+            self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-5).item())
+
+            for r, t in zip(dout_ref, dout_test):
+                self.assertEqual(r.shape, t.shape)
+                self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
+
+        def make_mask(tm_, tn_, batch, np_dtype):
+            arr_np_mask = np.random.normal(size=batch + (tm_, tn_)).astype(np_dtype)
+            arr_np_bool_mask = arr_np_mask < 0.0
+            arr_np_mask[arr_np_bool_mask] = 0.0
+
+            return mx.array(arr_np_bool_mask), mx.array(arr_np_mask)
+
         def test_shape(
             M,
             N,
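A note on the reference implementation above: expand_mask upsamples a block-granular mask to element granularity by inserting a singleton axis next to each tile axis, broadcasting both out to block_size, and collapsing the result back to two trailing dimensions; the final [..., :Y, :X] slice trims the padding that appears when a dimension is not a multiple of the block size. Below is a minimal standalone sketch of the same trick; the shapes are illustrative and not part of the diff (assumes mlx is installed):

    import mlx.core as mx

    block_size = 2
    M, K = 4, 6
    tm, tk = M // block_size, K // block_size

    block_mask = mx.array([[1, 0, 1], [0, 1, 1]])  # (tm, tk) block mask

    # (tm, tk) -> (tm, 1, tk, 1) -> (tm, bs, tk, bs) -> (M, K)
    m = mx.expand_dims(block_mask, (-3, -1))
    m = mx.broadcast_to(m, (tm, block_size, tk, block_size))
    elem_mask = m.reshape(tm * block_size, tk * block_size)[:M, :K]
    # Each entry of block_mask now covers a block_size x block_size tile.

The new bsx_shape term in the diff additionally prepends the broadcast of both operands' batch dimensions to the mask shape, so the expanded masks line up with broadcasted batched matmuls.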
@@ -737,49 +785,49 @@ class TestBlas(mlx_tests.MLXTestCase):
                 batch_A=batch_A,
                 batch_B=batch_B,
             ):
-                tm = (M + block_size - 1) // block_size
-                tn = (N + block_size - 1) // block_size
-                tk = (K + block_size - 1) // block_size
+                batch_out = np.broadcast_shapes(batch_A, batch_B)
+                cotan = mx.ones(batch_out + (M, N))
 
                 a_np = np.random.normal(size=batch_A + (M, K)).astype(np_dtype)
                 b_np = np.random.normal(size=batch_B + (K, N)).astype(np_dtype)
 
-                batch_out = np.broadcast_shapes(batch_A, batch_B)
+                a_mx = mx.array(a_np)
+                b_mx = mx.array(b_np)
 
-                a_np_mask = np.random.normal(size=batch_A + (tm, tk)) < 0.0
-                b_np_mask = np.random.normal(size=batch_B + (tk, tn)) < 0.0
-                out_np_mask = np.random.normal(size=batch_out + (tm, tn)) < 0.0
+                tm = (M + block_size - 1) // block_size
+                tn = (N + block_size - 1) // block_size
+                tk = (K + block_size - 1) // block_size
 
-                a_mx, b_mx, a_mx_mask, b_mx_mask, out_mx_mask = map(
-                    mx.array, (a_np, b_np, a_np_mask, b_np_mask, out_np_mask)
-                )
+                a_mx_bool_mask, a_mx_mask = make_mask(tm, tk, batch_A, np_dtype)
+                b_mx_bool_mask, b_mx_mask = make_mask(tk, tn, batch_B, np_dtype)
+                out_mx_bool_mask, out_mx_mask = make_mask(tm, tn, batch_out, np_dtype)
 
-                if transpose:
-                    b_np = np.random.normal(size=batch_B + (N, K)).astype(np_dtype)
-                    b_mx = mx.array(b_np)
-
-                    b_np = np.swapaxes(b_np, -2, -1)
-                    b_mx = mx.swapaxes(b_mx, -2, -1)
-
-                out_np = np_block_masked_mm(
-                    a_np, b_np, block_size, out_np_mask, a_np_mask, b_np_mask
-                )
-                out_mx = mx.block_masked_mm(
-                    a_mx, b_mx, block_size, out_mx_mask, a_mx_mask, b_mx_mask
-                )
-                self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
-
-                out_np = np_block_masked_mm(a_np, b_np, block_size, out_np_mask)
-                out_mx = mx.block_masked_mm(a_mx, b_mx, block_size, out_mx_mask)
-                self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
-
-                out_np = np_block_masked_mm(
-                    a_np, b_np, block_size, None, a_np_mask, b_np_mask
-                )
-                out_mx = mx.block_masked_mm(
-                    a_mx, b_mx, block_size, None, a_mx_mask, b_mx_mask
-                )
-                self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
+                # Boolean block masks
+                run_test(
+                    a_mx,
+                    b_mx,
+                    block_size,
+                    out_mx_bool_mask,
+                    a_mx_bool_mask,
+                    b_mx_bool_mask,
+                    cotan,
+                )
+                run_test(a_mx, b_mx, block_size, out_mx_bool_mask, None, None, cotan)
+                run_test(
+                    a_mx, b_mx, block_size, None, a_mx_bool_mask, b_mx_bool_mask, cotan
+                )
+
+                # Float block masks
+                run_test(
+                    a_mx, b_mx, block_size, out_mx_mask, a_mx_mask, b_mx_mask, cotan
+                )
+                run_test(a_mx, b_mx, block_size, None, a_mx_mask, b_mx_mask, cotan)
+                run_test_mask_vjp(
+                    a_mx, b_mx, block_size, out_mx_mask, a_mx_mask, b_mx_mask, cotan
+                )
+                run_test_mask_vjp(
+                    a_mx, b_mx, block_size, None, a_mx_mask, b_mx_mask, cotan
+                )
 
         shapes = (
             (16, 16, 16, 32),
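The new run_test helpers lean on mx.vjp, which returns a function's outputs together with its vector-Jacobian products for the supplied cotangents, so one call per implementation exercises the forward result and (in run_test_mask_vjp, which adds the float masks to the primals) the gradients, including those flowing into the masks. A stripped-down sketch of the pattern, shown on plain matmul with hypothetical shapes (assumes mlx is installed):

    import mlx.core as mx

    a = mx.random.normal((8, 8))
    b = mx.random.normal((8, 8))
    cotan = mx.ones((8, 8))  # cotangent for the single output

    def f_ref(a_, b_):       # stands in for the reference implementation
        return a_ @ b_

    def f_test(a_, b_):      # stands in for the kernel under test
        return mx.matmul(a_, b_)

    out_ref, dout_ref = mx.vjp(f_ref, [a, b], [cotan])
    out_test, dout_test = mx.vjp(f_test, [a, b], [cotan])
    assert mx.allclose(out_ref[0], out_test[0]).item()  # forward agreement
    for r, t in zip(dout_ref, dout_test):               # gradient agreement
        assert mx.allclose(r, t).item()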
@@ -789,11 +837,10 @@ class TestBlas(mlx_tests.MLXTestCase):
         )
 
         for M, N, K, block_size in shapes:
-            test_shape(M, N, K, block_size, transpose=False)
-            test_shape(M, N, K, block_size, transpose=True)
+            test_shape(M, N, K, block_size)
 
         # Test broadcasting
-        test_shape(64, 64, 64, 32, transpose=False, batch_A=(1, 2), batch_B=(2, 2))
+        test_shape(64, 64, 64, 32, batch_A=(1, 2), batch_B=(2, 2))
 
         # Test gemv
         a_np = np.random.normal(size=(64, 64)).astype(np.float32)
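For context, mx.block_masked_mm takes block-granular masks whose trailing dimensions are the operand dimensions divided by block_size, rounded up; the positional argument order below matches the calls in the test. A usage sketch with arbitrary mask values (assumes mlx is installed):

    import mlx.core as mx

    M = N = K = 64
    block_size = 32
    a = mx.random.normal((M, K))
    b = mx.random.normal((K, N))

    lhs_mask = mx.array([[True, False], [True, True]])  # (M//32, K//32)
    rhs_mask = mx.array([[True, True], [False, True]])  # (K//32, N//32)
    out_mask = mx.array([[True, True], [True, False]])  # (M//32, N//32)

    out = mx.block_masked_mm(a, b, block_size, out_mask, lhs_mask, rhs_mask)
    print(out.shape)  # (64, 64)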