diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index c30177966..e62c68098 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -16,6 +16,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/internal/tuner/ops.cpp ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h) if(MLX_BUILD_CPU) diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index 104ad6d69..b28105455 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -23,6 +23,7 @@ nanobind_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp ${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/internal.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp) if(NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY) diff --git a/python/src/internal.cpp b/python/src/internal.cpp new file mode 100644 index 000000000..baa27ec79 --- /dev/null +++ b/python/src/internal.cpp @@ -0,0 +1,59 @@ +// Copyright © 2023-2024 Apple Inc. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "python/src/utils.h" + +#include "mlx/internal/tuner/ops.h" +#include "mlx/ops.h" + +namespace nb = nanobind; +using namespace nb::literals; +using namespace mlx::core; + +void init_internal(nb::module_& parent_module) { + auto m = parent_module.def_submodule( + "internal", "mlx.core.internal: internal operations"); + + m.def( + "tunable_matmul", + &internal::tunable_matmul, + nb::arg(), + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def tunable_matmul(a: array, b: array, tparams: dict[str, int], /, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Matrix multiplication. + + Perform the (possibly batched) matrix multiplication of two arrays. This function supports + broadcasting for arrays with more than two dimensions. + + - If the first array is 1-D then a 1 is prepended to its shape to make it + a matrix. Similarly if the second array is 1-D then a 1 is appended to its + shape to make it a matrix. In either case the singleton dimension is removed + from the result. + - A batched matrix multiplication is performed if the arrays have more than + 2 dimensions. The matrix dimensions for the matrix product are the last + two dimensions of each input. + - All but the last two dimensions of each input are broadcast with one another using + standard numpy-style broadcasting semantics. + + Args: + a (array): Input array or scalar. + b (array): Input array or scalar. + tparams (dict[str, int]): Matmul tunable parameters + + Returns: + array: The matrix product of ``a`` and ``b``. + )pbdoc"); +} \ No newline at end of file diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp index a261c1f88..4d6d17ff2 100644 --- a/python/src/mlx.cpp +++ b/python/src/mlx.cpp @@ -19,6 +19,7 @@ void init_linalg(nb::module_&); void init_constants(nb::module_&); void init_fast(nb::module_&); void init_distributed(nb::module_&); +void init_internal(nb::module_&); NB_MODULE(core, m) { m.doc() = "mlx: A framework for machine learning on Apple silicon."; @@ -39,6 +40,7 @@ NB_MODULE(core, m) { init_constants(m); init_fast(m); init_distributed(m); + init_internal(m); m.attr("__version__") = TOSTRING(_VERSION_); }