mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Adds mx.fast.layer_norm (#870)
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							105d236889
						
					
				
				
					commit
					2225374060
				
			@@ -46,6 +46,42 @@ void init_fast(nb::module_& parent_module) {
 | 
			
		||||
            array: The output array.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
 | 
			
		||||
  m.def(
 | 
			
		||||
      "layer_norm",
 | 
			
		||||
      [](const array& x,
 | 
			
		||||
         const std::optional<array>& weight,
 | 
			
		||||
         const std::optional<array>& bias,
 | 
			
		||||
         float eps,
 | 
			
		||||
         const StreamOrDevice& s /* = {} */) {
 | 
			
		||||
        return fast::layer_norm(x, weight, bias, eps, s);
 | 
			
		||||
      },
 | 
			
		||||
      "x"_a,
 | 
			
		||||
      "weight"_a.none(),
 | 
			
		||||
      "bias"_a.none(),
 | 
			
		||||
      "eps"_a,
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def layer_norm(x: array, weight: Optional[array], bias: Optional[array], eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Layer normalization.
 | 
			
		||||
 | 
			
		||||
        The normalization is with respect to the last axis of the input ``x``.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            x (array): Input array.
 | 
			
		||||
            weight (array, optional): A multiplicative weight to scale the result by.
 | 
			
		||||
              The ``weight`` should be one-dimensional with the same size
 | 
			
		||||
              as the last axis of ``x``. If set to ``None`` then no scaling happens.
 | 
			
		||||
            bias (array, optional): An additive offset to be added to the result.
 | 
			
		||||
              The ``bias`` should be one-dimensional with the same size
 | 
			
		||||
              as the last axis of ``x``. If set to ``None`` then no translation happens.
 | 
			
		||||
            eps (float): A small additive constant for numerical stability.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            array: The output array.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
 | 
			
		||||
  m.def(
 | 
			
		||||
      "rope",
 | 
			
		||||
      [](const array& a,
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user