mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 01:18:12 +08:00
Extensions (#962)
* start to fix extensions * mostly fixed extensions * fix extension build * couple more nits
This commit is contained in:
@@ -1,31 +1,31 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
|
||||
#include "axpby/axpby.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
PYBIND11_MODULE(mlx_sample_extensions, m) {
|
||||
m.doc() = "Sample C++ and metal extensions for MLX";
|
||||
NB_MODULE(_ext, m) {
|
||||
m.doc() = "Sample extension for MLX";
|
||||
|
||||
m.def(
|
||||
"axpby",
|
||||
&axpby,
|
||||
"x"_a,
|
||||
"y"_a,
|
||||
py::pos_only(),
|
||||
"alpha"_a,
|
||||
"beta"_a,
|
||||
py::kw_only(),
|
||||
"stream"_a = py::none(),
|
||||
R"pbdoc(
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
R"(
|
||||
Scale and sum two vectors element-wise
|
||||
``z = alpha * x + beta * y``
|
||||
|
||||
|
||||
Follows numpy style broadcasting between ``x`` and ``y``
|
||||
Inputs are upcasted to floats if needed
|
||||
|
||||
@@ -37,5 +37,5 @@ PYBIND11_MODULE(mlx_sample_extensions, m) {
|
||||
|
||||
Returns:
|
||||
array: ``alpha * x + beta * y``
|
||||
)pbdoc");
|
||||
}
|
||||
)");
|
||||
}
|
||||
|
Reference in New Issue
Block a user