mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Allow unary ops to accept array like (#1093)
This commit is contained in:
		@@ -158,7 +158,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "abs",
 | 
			
		||||
      &mlx::core::abs,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::abs(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -175,7 +177,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "sign",
 | 
			
		||||
      &sign,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return sign(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -192,7 +196,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "negative",
 | 
			
		||||
      &negative,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return negative(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -600,7 +606,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "square",
 | 
			
		||||
      &square,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return square(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -617,7 +625,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "sqrt",
 | 
			
		||||
      &mlx::core::sqrt,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::sqrt(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -634,7 +644,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "rsqrt",
 | 
			
		||||
      &rsqrt,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return rsqrt(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -651,7 +663,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "reciprocal",
 | 
			
		||||
      &reciprocal,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return reciprocal(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -757,7 +771,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "exp",
 | 
			
		||||
      &mlx::core::exp,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::exp(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -774,7 +790,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "expm1",
 | 
			
		||||
      &mlx::core::expm1,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::expm1(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -793,7 +811,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "erf",
 | 
			
		||||
      &mlx::core::erf,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::erf(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -813,7 +833,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "erfinv",
 | 
			
		||||
      &mlx::core::erfinv,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::erfinv(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -830,7 +852,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "sin",
 | 
			
		||||
      &mlx::core::sin,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::sin(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -847,7 +871,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "cos",
 | 
			
		||||
      &mlx::core::cos,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::cos(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -864,7 +890,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "tan",
 | 
			
		||||
      &mlx::core::tan,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::tan(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -881,7 +909,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "arcsin",
 | 
			
		||||
      &mlx::core::arcsin,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::arcsin(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -898,7 +928,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "arccos",
 | 
			
		||||
      &mlx::core::arccos,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::arccos(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -915,7 +947,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "arctan",
 | 
			
		||||
      &mlx::core::arctan,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::arctan(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -951,7 +985,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "sinh",
 | 
			
		||||
      &mlx::core::sinh,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::sinh(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -968,7 +1004,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "cosh",
 | 
			
		||||
      &mlx::core::cosh,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::cosh(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -985,7 +1023,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "tanh",
 | 
			
		||||
      &mlx::core::tanh,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::tanh(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -1002,7 +1042,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "arcsinh",
 | 
			
		||||
      &mlx::core::arcsinh,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::arcsinh(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -1019,7 +1061,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "arccosh",
 | 
			
		||||
      &mlx::core::arccosh,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::arccosh(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -1036,7 +1080,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "arctanh",
 | 
			
		||||
      &mlx::core::arctanh,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::arctanh(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -1053,7 +1099,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "degrees",
 | 
			
		||||
      &mlx::core::degrees,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return degrees(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -1070,7 +1118,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
    )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "radians",
 | 
			
		||||
      &mlx::core::radians,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::radians(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -1087,7 +1137,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
    )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "log",
 | 
			
		||||
      &mlx::core::log,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::log(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -1104,7 +1156,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "log2",
 | 
			
		||||
      &mlx::core::log2,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::log2(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -1121,7 +1175,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "log10",
 | 
			
		||||
      &mlx::core::log10,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::log10(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -1138,7 +1194,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "log1p",
 | 
			
		||||
      &mlx::core::log1p,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::log1p(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -1176,7 +1234,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "sigmoid",
 | 
			
		||||
      &sigmoid,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return sigmoid(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -1837,7 +1897,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "floor",
 | 
			
		||||
      &mlx::core::floor,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::floor(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -1854,7 +1916,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "ceil",
 | 
			
		||||
      &mlx::core::ceil,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::ceil(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -1871,7 +1935,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "isnan",
 | 
			
		||||
      &mlx::core::isnan,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::isnan(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -1888,7 +1954,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "isinf",
 | 
			
		||||
      &mlx::core::isinf,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return mlx::core::isinf(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -1905,7 +1973,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "isposinf",
 | 
			
		||||
      &isposinf,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return isposinf(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -1923,7 +1993,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "isneginf",
 | 
			
		||||
      &isneginf,
 | 
			
		||||
      [](const ScalarOrArray& a, StreamOrDevice s) {
 | 
			
		||||
        return isneginf(to_array(a), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
@@ -3409,8 +3481,8 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "round",
 | 
			
		||||
      [](const array& a, int decimals, StreamOrDevice s) {
 | 
			
		||||
        return round(a, decimals, s);
 | 
			
		||||
      [](const ScalarOrArray& a, int decimals, StreamOrDevice s) {
 | 
			
		||||
        return round(to_array(a), decimals, s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      "decimals"_a = 0,
 | 
			
		||||
@@ -3750,15 +3822,15 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Matrix multiplication with matrix-level gather.
 | 
			
		||||
 | 
			
		||||
        Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays. 
 | 
			
		||||
        Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays.
 | 
			
		||||
        This operation is more efficient than explicitly applying a :func:``take`` followed by a :func:``matmul``.
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of ``a`` and ``b`` respectively.
 | 
			
		||||
 | 
			
		||||
        For ``a`` with shape ``(A1, A2, ..., AS, M, K)``, 
 | 
			
		||||
        For ``a`` with shape ``(A1, A2, ..., AS, M, K)``,
 | 
			
		||||
        ``lhs_indices`` contains indices from the range ``[0, A1 * A2 * ... * AS)``
 | 
			
		||||
 | 
			
		||||
        For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, 
 | 
			
		||||
        For ``b`` with shape ``(B1, B2, ..., BS, M, K)``,
 | 
			
		||||
        ``rhs_indices`` contains indices from the range ``[0, B1 * B2 * ... * BS)``
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user