mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 02:28:13 +08:00 
			
		
		
		
	MoE backward improvements (#2335)
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							a4fcc893cd
						
					
				
				
					commit
					4a9b29a875
				
			@@ -549,6 +549,49 @@ class TestQuantized(mlx_tests.MLXTestCase):
 | 
			
		||||
            self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
 | 
			
		||||
            self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
 | 
			
		||||
 | 
			
		||||
    def test_gather_qmm_grad(self):
        """Compare ``mx.gather_qmm`` with an explicit gather + ``mx.quantized_matmul``.

        Both implementations are pushed through ``mx.vjp`` with the same
        inputs and cotangent; the forward output and the cotangents with
        respect to ``x``, the scales and the biases must agree within the
        tolerances asserted at the end of the test.
        """

        def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort):
            # Reference path: index the operands by hand, then run a plain
            # quantized matmul. ``sort`` is accepted only so the signature
            # mirrors ``gather_qmm`` below; it is not used here.
            if lhs is not None:
                x = x[lhs]
            if rhs is not None:
                w = w[rhs]
                s = s[rhs]
                b = b[rhs]
            return mx.quantized_matmul(x, w, s, b, transpose=trans)

        def gather_qmm(x, w, s, b, lhs, rhs, trans, sort):
            # Path under test: the fused op receives the indices directly.
            return mx.gather_qmm(
                x,
                w,
                s,
                b,
                transpose=trans,
                lhs_indices=lhs,
                rhs_indices=rhs,
                sorted_indices=sort,
            )

        # Seed the generator so the fixed-tolerance comparisons below run on
        # the same data every time and any failure can be reproduced.
        mx.random.seed(0)

        x = mx.random.normal((16, 1, 256))
        w, s, b = mx.quantize(mx.random.normal((4, 256, 256)))
        # The indices are sorted up front because sorted_indices=True is
        # passed to gather_qmm.
        indices = mx.sort(mx.random.randint(0, 4, shape=(16,)))
        cotan = mx.random.normal((16, 1, 256))

        # ``w`` is captured by the lambdas rather than passed as a primal, so
        # only x, s and b receive cotangents.
        (o1,), (dx1, ds1, db1) = mx.vjp(
            lambda x, s, b: gather_qmm_ref(x, w, s, b, None, indices, True, True),
            [x, s, b],
            [cotan],
        )
        (o2,), (dx2, ds2, db2) = mx.vjp(
            lambda x, s, b: gather_qmm(x, w, s, b, None, indices, True, True),
            [x, s, b],
            [cotan],
        )

        self.assertTrue(mx.allclose(o1, o2, atol=1e-4))
        self.assertTrue(mx.allclose(dx1, dx2, atol=1e-4))
        # The scale/bias cotangents are compared with a looser tolerance.
        self.assertTrue(mx.allclose(ds1, ds2, atol=1e-3))
        self.assertTrue(mx.allclose(db1, db2, atol=1e-3))
 | 
			
		||||
 | 
			
		||||
    def test_vjp_scales_biases(self):
 | 
			
		||||
        mx.random.seed(0)
 | 
			
		||||
        x = mx.random.normal(shape=(2, 2, 512))
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user