mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Option for precise softmax (#953)
* precise softmax * Add an equivalency check * Make the threadgroup memory definition fixed * precise cpu softmax * precise option on cpu * remove print --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
		@@ -2430,12 +2430,13 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "softmax",
 | 
			
		||||
      [](const array& a, const IntOrVec& axis, StreamOrDevice s) {
 | 
			
		||||
        return softmax(a, get_reduce_axes(axis, a.ndim()), s);
 | 
			
		||||
      [](const array& a, const IntOrVec& axis, bool precise, StreamOrDevice s) {
 | 
			
		||||
        return softmax(a, get_reduce_axes(axis, a.ndim()), precise, s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      "axis"_a = nb::none(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "precise"_a = false,
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def softmax(a: array, /, axis: Union[None, int, Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user