mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 02:28:13 +08:00 
			
		
		
		
	MoE backward improvements (#2335)
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							a4fcc893cd
						
					
				
				
					commit
					4a9b29a875
				
			@@ -1163,6 +1163,99 @@ class TestBlas(mlx_tests.MLXTestCase):
 | 
			
		||||
            self.assertEqual(r.shape, t.shape)
 | 
			
		||||
            self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
 | 
			
		||||
 | 
			
		||||
    def test_gather_mm_sorted(self):
        """Compare ``mx.gather_mm(..., sorted_indices=True)`` with a reference.

        The reference indexes ``b`` with the right-hand-side indices and then
        performs a plain matmul. Both the forward output and the VJPs with
        respect to ``a`` and ``b`` must agree to ``atol=1e-4``.
        """

        def reference(lhs, mats, idx):
            # Gather the selected matrices explicitly, then multiply.
            return lhs @ mats[idx]

        def fused(lhs, mats, idx):
            return mx.gather_mm(lhs, mats, rhs_indices=idx, sorted_indices=True)

        a = mx.random.normal((100, 1, 100))
        b = mx.random.normal((8, 100, 100))
        # The indices are sorted before being passed with sorted_indices=True.
        rhs = mx.sort(mx.random.randint(0, 8, shape=(100,)))

        # Forward pass.
        expected = reference(a, b, rhs)
        actual = fused(a, b, rhs)
        self.assertTrue(mx.allclose(expected, actual, atol=1e-4))

        # Backward pass: same primals and cotangent for both implementations.
        cotan = mx.random.normal(expected.shape)
        primals = [a, b]
        cotangents = [cotan]
        expected, grads_expected = mx.vjp(
            lambda x, w: reference(x, w, rhs), primals, cotangents
        )
        actual, grads_actual = mx.vjp(
            lambda x, w: fused(x, w, rhs), primals, cotangents
        )
        self.assertTrue(mx.allclose(expected[0], actual[0], atol=1e-4))
        # One gradient per primal: first for ``a``, then for ``b``.
        for grad_ref, grad_test in zip(grads_expected, grads_actual):
            self.assertTrue(mx.allclose(grad_ref, grad_test, atol=1e-4))
 | 
			
		||||
 | 
			
		||||
    def test_segmented_mm(self):
        """Check ``mx.segmented_mm`` against a slice-and-matmul reference.

        Exercises several ``(M, N, K)`` shapes, three segment layouts and all
        four combinations of transposed operands. Also checks that a 3-D
        operand raises ``ValueError`` and that segments with extra leading
        dimensions yield an output of shape ``(2, 2, 4, 10, 10)``.
        """

        def segmented_mm_ref(a, b, s):
            # One product per (start, stop) row of ``s``, restricted to that
            # slice of the shared dimension, stacked on a new leading axis.
            return mx.stack(
                [a[:, lo:hi] @ b[lo:hi, :] for lo, hi in s.tolist()],
                axis=0,
            )

        shapes = [
            (10, 10, 10),
            (10, 10, 1000),
            (1000, 1000, 1000),
        ]
        # Segment boundaries given as fractions; they are scaled by K below.
        all_segments = [[0, 0, 1.0], [0, 0.5, 1.0], [r / 9 for r in range(10)]]
        # (transpose a, transpose b) for each operand-layout case.
        layouts = [(False, False), (True, False), (False, True), (True, True)]

        for M, N, K in shapes:
            for fractions in all_segments:
                # Consecutive boundaries form the [start, stop] pairs.
                pairs = [[lo, hi] for lo, hi in zip(fractions, fractions[1:])]
                segments = mx.array(pairs)
                segments = mx.minimum(K - 1, (K * segments).astype(mx.uint32))

                for a_transposed, b_transposed in layouts:
                    # ``a`` is always drawn before ``b``; a transposed operand
                    # is created in the swapped shape and viewed through ``.T``.
                    if a_transposed:
                        a = mx.random.normal((K, M)).T
                    else:
                        a = mx.random.normal((M, K))
                    if b_transposed:
                        b = mx.random.normal((N, K)).T
                    else:
                        b = mx.random.normal((K, N))

                    c1 = segmented_mm_ref(a, b, segments)
                    c2 = mx.segmented_mm(a, b, segments)
                    self.assertTrue(mx.allclose(c1, c2, atol=1e-4))

        # A 3-D operand is expected to be rejected.
        with self.assertRaises(ValueError):
            batched = mx.ones((2, 10, 10))
            two_segments = mx.array([[0, 5], [5, 10]]).astype(mx.uint32)
            mx.segmented_mm(batched, batched, two_segments)

        # Build 16 contiguous segments from a random histogram over 16 bins:
        # count, sort the counts, and accumulate them into boundaries.
        a = mx.ones((10, 1000))
        bins = mx.random.randint(0, 16, shape=(1000,))
        counts = mx.zeros(16, dtype=bins.dtype).at[bins].add(1)
        ends = mx.cumsum(mx.sort(counts))
        bounds = mx.concatenate([mx.array([0]), ends])
        # Overlapping windows over the 17 boundaries; presumably row i is
        # (bounds[i], bounds[i + 1]) - assumes strides are in elements.
        s = mx.as_strided(bounds, (16, 2), (1, 1))
        s = mx.reshape(s, (2, 2, 4, 2))
        c = mx.segmented_mm(a, a.T, s)
        self.assertEqual(c.shape, (2, 2, 4, 10, 10))
 | 
			
		||||
 | 
			
		||||
    def test_gemv_gemm_same_precision(self):
 | 
			
		||||
        mx.random.seed(0)
 | 
			
		||||
        N = 256
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user