// Copyright © 2023-2024 Apple Inc. #include #include #include "axpby/axpby.h" namespace nb = nanobind; using namespace nb::literals; NB_MODULE(_ext, m) { m.doc() = "Sample extension for MLX"; m.def( "axpby", &my_ext::axpby, "x"_a, "y"_a, "alpha"_a, "beta"_a, nb::kw_only(), "stream"_a = nb::none(), R"( Scale and sum two vectors element-wise ``z = alpha * x + beta * y`` Follows numpy style broadcasting between ``x`` and ``y`` Inputs are upcasted to floats if needed Args: x (array): Input array. y (array): Input array. alpha (float): Scaling factor for ``x``. beta (float): Scaling factor for ``y``. Returns: array: ``alpha * x + beta * y`` )"); }