2024-04-09 23:50:36 +08:00
|
|
|
// Copyright © 2023-2024 Apple Inc.
|
2023-12-01 03:12:53 +08:00
|
|
|
|
2024-04-09 23:50:36 +08:00
|
|
|
#include <nanobind/nanobind.h>
|
|
|
|
#include <nanobind/stl/variant.h>
|
2023-11-30 02:42:59 +08:00
|
|
|
|
|
|
|
#include "axpby/axpby.h"
|
|
|
|
|
2024-04-09 23:50:36 +08:00
|
|
|
namespace nb = nanobind;
|
|
|
|
using namespace nb::literals;
|
|
|
|
|
|
|
|
// Python extension module `_ext`: exposes my_ext::axpby to Python via
// nanobind.
NB_MODULE(_ext, m) {
  m.doc() = "Sample extension for MLX";

  // Bind my_ext::axpby as `axpby(x, y, alpha, beta, *, stream=None)`.
  // The four positional arguments map one-to-one onto the C++ parameters;
  // nb::kw_only() makes every following argument (here only `stream`)
  // keyword-only, and `stream` defaults to Python `None`.
  // NOTE(review): assumes the trailing C++ parameter is MLX's
  // StreamOrDevice (a std::variant, which is why <nanobind/stl/variant.h>
  // is included) so that `None` converts to the default stream/device —
  // confirm against axpby/axpby.h.
  m.def(
      "axpby",
      &my_ext::axpby,
      "x"_a,
      "y"_a,
      "alpha"_a,
      "beta"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
      R"(
        Scale and sum two vectors element-wise:
        ``z = alpha * x + beta * y``

        Follows NumPy-style broadcasting between ``x`` and ``y``.
        Inputs are upcast to floating point if needed.

        Args:
            x (array): Input array.
            y (array): Input array.
            alpha (float): Scaling factor for ``x``.
            beta (float): Scaling factor for ``y``.
            stream (Stream or Device, optional): Keyword-only. Stream or
                device to run the operation on. Defaults to ``None``.

        Returns:
            array: ``alpha * x + beta * y``
      )");
}
|