mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	allow take to work with integer index (#1440)
This commit is contained in:
		@@ -1398,13 +1398,15 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
  m.def(
 | 
			
		||||
      "take",
 | 
			
		||||
      [](const array& a,
 | 
			
		||||
         const array& indices,
 | 
			
		||||
         const std::variant<int, array>& indices,
 | 
			
		||||
         const std::optional<int>& axis,
 | 
			
		||||
         StreamOrDevice s) {
 | 
			
		||||
        if (axis.has_value()) {
 | 
			
		||||
          return take(a, indices, axis.value(), s);
 | 
			
		||||
        if (auto pv = std::get_if<int>(&indices); pv) {
 | 
			
		||||
          return axis ? take(a, *pv, axis.value(), s) : take(a, *pv, s);
 | 
			
		||||
        } else {
 | 
			
		||||
          return take(a, indices, s);
 | 
			
		||||
          auto indices_ = std::get<array>(indices);
 | 
			
		||||
          return axis ? take(a, indices_, axis.value(), s)
 | 
			
		||||
                      : take(a, indices_, s);
 | 
			
		||||
        }
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
@@ -1413,7 +1415,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def take(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
          "def take(a: array, /, indices: Union[int, array], axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Take elements along an axis.
 | 
			
		||||
 | 
			
		||||
@@ -1425,7 +1427,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            a (array): Input array.
 | 
			
		||||
            indices (array): Input array with integral type.
 | 
			
		||||
            indices (int or array): Integer index or input array with integral type.
 | 
			
		||||
            axis (int, optional): Axis along which to perform the take. If unspecified
 | 
			
		||||
              the array is treated as a flattened 1-D vector.
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user