mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Put along axis + fixe for partition grad (#1430)
* put along axis, fixes for partition grad * zeros for arg reduce
This commit is contained in:
		@@ -1463,7 +1463,48 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
              operation.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            array: The output array with the specified shape and values.
 | 
			
		||||
            array: The output array.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "put_along_axis",
 | 
			
		||||
      [](const array& a,
 | 
			
		||||
         const array& indices,
 | 
			
		||||
         const array& values,
 | 
			
		||||
         const std::optional<int>& axis,
 | 
			
		||||
         StreamOrDevice s) {
 | 
			
		||||
        if (axis.has_value()) {
 | 
			
		||||
          return put_along_axis(a, indices, values, axis.value(), s);
 | 
			
		||||
        } else {
 | 
			
		||||
          return reshape(
 | 
			
		||||
              put_along_axis(reshape(a, {-1}, s), indices, values, 0, s),
 | 
			
		||||
              a.shape(),
 | 
			
		||||
              s);
 | 
			
		||||
        }
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      "indices"_a,
 | 
			
		||||
      "values"_a,
 | 
			
		||||
      "axis"_a.none(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def put_along_axis(a: array, /, indices: array, values: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Put values along an axis at the specified indices.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            a (array): Destination array.
 | 
			
		||||
            indices (array): Indices array. These should be broadcastable with
 | 
			
		||||
              the input array excluding the `axis` dimension.
 | 
			
		||||
            values (array): Values array. These should be broadcastable with
 | 
			
		||||
              the indices.
 | 
			
		||||
 | 
			
		||||
            axis (int or None): Axis in the destination to put the values to. If
 | 
			
		||||
              ``axis == None`` the destination is flattened prior to the put
 | 
			
		||||
              operation.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            array: The output array.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "full",
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user