mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	Layer norm grad fix donation bug (#941)
* add layer norm grad test * Fix donation bug in layernorm vjp --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							9cbff5ec1d
						
					
				
				
					commit
					110d9b149d
				
			@@ -355,7 +355,14 @@ void LayerNormVJP::eval_gpu(
 | 
			
		||||
    ReductionPlan plan(
 | 
			
		||||
        ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
 | 
			
		||||
    strided_reduce_general_dispatch(
 | 
			
		||||
        g, gb, "sum", plan, {0}, compute_encoder, d, s);
 | 
			
		||||
        g_in_gx ? gx : (g_in_gw ? gw_temp : g),
 | 
			
		||||
        gb,
 | 
			
		||||
        "sum",
 | 
			
		||||
        plan,
 | 
			
		||||
        {0},
 | 
			
		||||
        compute_encoder,
 | 
			
		||||
        d,
 | 
			
		||||
        s);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const int simd_size = 32;
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user