mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Add init python binding for tunable matmul
This commit is contained in:
parent
2ed2e0e3da
commit
e21143961c
@ -16,6 +16,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/internal/tuner/ops.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
|
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
|
||||||
|
|
||||||
if(MLX_BUILD_CPU)
|
if(MLX_BUILD_CPU)
|
||||||
|
@ -23,6 +23,7 @@ nanobind_add_module(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/internal.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
|
||||||
|
|
||||||
if(NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
|
if(NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
|
||||||
|
59
python/src/internal.cpp
Normal file
59
python/src/internal.cpp
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
#include <nanobind/nanobind.h>
|
||||||
|
#include <nanobind/stl/optional.h>
|
||||||
|
#include <nanobind/stl/pair.h>
|
||||||
|
#include <nanobind/stl/string.h>
|
||||||
|
#include <nanobind/stl/tuple.h>
|
||||||
|
#include <nanobind/stl/unordered_map.h>
|
||||||
|
#include <nanobind/stl/variant.h>
|
||||||
|
#include <nanobind/stl/vector.h>
|
||||||
|
|
||||||
|
#include "python/src/utils.h"
|
||||||
|
|
||||||
|
#include "mlx/internal/tuner/ops.h"
|
||||||
|
#include "mlx/ops.h"
|
||||||
|
|
||||||
|
namespace nb = nanobind;
|
||||||
|
using namespace nb::literals;
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
void init_internal(nb::module_& parent_module) {
|
||||||
|
auto m = parent_module.def_submodule(
|
||||||
|
"internal", "mlx.core.internal: internal operations");
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"tunable_matmul",
|
||||||
|
&internal::tunable_matmul,
|
||||||
|
nb::arg(),
|
||||||
|
nb::arg(),
|
||||||
|
nb::arg(),
|
||||||
|
nb::kw_only(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def tunable_matmul(a: array, b: array, tparams: dict[str, int], /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
|
R"pbdoc(
|
||||||
|
Matrix multiplication.
|
||||||
|
|
||||||
|
Perform the (possibly batched) matrix multiplication of two arrays. This function supports
|
||||||
|
broadcasting for arrays with more than two dimensions.
|
||||||
|
|
||||||
|
- If the first array is 1-D then a 1 is prepended to its shape to make it
|
||||||
|
a matrix. Similarly if the second array is 1-D then a 1 is appended to its
|
||||||
|
shape to make it a matrix. In either case the singleton dimension is removed
|
||||||
|
from the result.
|
||||||
|
- A batched matrix multiplication is performed if the arrays have more than
|
||||||
|
2 dimensions. The matrix dimensions for the matrix product are the last
|
||||||
|
two dimensions of each input.
|
||||||
|
- All but the last two dimensions of each input are broadcast with one another using
|
||||||
|
standard numpy-style broadcasting semantics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array or scalar.
|
||||||
|
b (array): Input array or scalar.
|
||||||
|
tparams (dict[str, int]): Matmul tunable parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The matrix product of ``a`` and ``b``.
|
||||||
|
)pbdoc");
|
||||||
|
}
|
@ -19,6 +19,7 @@ void init_linalg(nb::module_&);
|
|||||||
void init_constants(nb::module_&);
|
void init_constants(nb::module_&);
|
||||||
void init_fast(nb::module_&);
|
void init_fast(nb::module_&);
|
||||||
void init_distributed(nb::module_&);
|
void init_distributed(nb::module_&);
|
||||||
|
void init_internal(nb::module_&);
|
||||||
|
|
||||||
NB_MODULE(core, m) {
|
NB_MODULE(core, m) {
|
||||||
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
|
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
|
||||||
@ -39,6 +40,7 @@ NB_MODULE(core, m) {
|
|||||||
init_constants(m);
|
init_constants(m);
|
||||||
init_fast(m);
|
init_fast(m);
|
||||||
init_distributed(m);
|
init_distributed(m);
|
||||||
|
init_internal(m);
|
||||||
|
|
||||||
m.attr("__version__") = TOSTRING(_VERSION_);
|
m.attr("__version__") = TOSTRING(_VERSION_);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user