mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Fast Hadamard Transform (#1249)
* Working hadamard for powers of 2 * working for m*2^k * add scale and check contiguity * add size check * clean up * fix test * add grads + vmap * gpu only * skip on linux * test typo * add cpu impl * remove gpu only tests * fix linux build + add is_equivalent
This commit is contained in:
		@@ -4372,6 +4372,35 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
            a (array): Input array or scalar.
 | 
			
		||||
            dtype (Dtype): The data type to change to.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            array: The array with the new type.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "hadamard_transform",
 | 
			
		||||
      &hadamard_transform,
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      "scale"_a = 1.0,
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def hadamard_transform(a: array, float scale = 1.0, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Perform the Walsh-Hadamard transform along the final axis.
 | 
			
		||||
 | 
			
		||||
        Equivalent to:
 | 
			
		||||
        ```python
 | 
			
		||||
        from scipy.linalg import hadamard
 | 
			
		||||
 | 
			
		||||
        y = hadamard(len(x)) @ x
 | 
			
		||||
        ```
 | 
			
		||||
 | 
			
		||||
        Supports sizes `n = m*2^k` where m in (1, 12, 20, 28)
 | 
			
		||||
          and 2^k <= 8192 for FP32 and 2^k <= 16384 for FP16/BF16.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            a (array): Input array or scalar.
 | 
			
		||||
            scale (float): Scale the output by this factor.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            array: The array with the new type.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user