mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00

* Remove "using namespace mlx::core" in benchmarks/examples
* Fix building example extension
* A missing one in comment
* Fix building on M chips
40 lines
880 B
C++
// Copyright © 2023-2024 Apple Inc.
|
|
|
|
#include <nanobind/nanobind.h>
|
|
#include <nanobind/stl/variant.h>
|
|
|
|
#include "axpby/axpby.h"
|
|
|
|
namespace nb = nanobind;
|
|
using namespace nb::literals;
|
|
|
|
// Python bindings for the sample MLX extension, built as the `_ext` module.
//
// Exposes `my_ext::axpby` (declared in "axpby/axpby.h") to Python as `axpby`.
// The trailing raw string passed to `m.def` is registered as the Python
// docstring, so its `Args:` section should list every argument the binding
// accepts, including the keyword-only `stream`.
NB_MODULE(_ext, m) {
  m.doc() = "Sample extension for MLX";

  m.def(
      "axpby",
      &my_ext::axpby,
      "x"_a,
      "y"_a,
      "alpha"_a,
      "beta"_a,
      // Every argument after this marker must be passed by keyword from Python.
      nb::kw_only(),
      // Defaults to Python None. Presumably nanobind/stl/variant.h converts
      // this to the empty alternative of the stream/device variant that
      // my_ext::axpby takes. TODO confirm against axpby/axpby.h.
      "stream"_a = nb::none(),
      R"(
        Scale and sum two vectors element-wise
        ``z = alpha * x + beta * y``

        Follows numpy style broadcasting between ``x`` and ``y``
        Inputs are upcasted to floats if needed

        Args:
            x (array): Input array.
            y (array): Input array.
            alpha (float): Scaling factor for ``x``.
            beta (float): Scaling factor for ``y``.
            stream (Stream, optional): Keyword-only. Stream or device to run
              the operation on. Defaults to ``None``.

        Returns:
            array: ``alpha * x + beta * y``
      )");
}
|