mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
python bindings
This commit is contained in:
parent
8ded7c8d37
commit
3fb3385a1a
@ -3194,4 +3194,26 @@ void init_ops(py::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
result (array): The dequantized version of ``w``
|
result (array): The dequantized version of ``w``
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"tensordot",
|
||||||
|
[](const array& a,
|
||||||
|
const array& b,
|
||||||
|
const IntOrIntVec& dims,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
if (std::holds_alternative<std::monostate>(dims)) {
|
||||||
|
return tensordot(a, b, 2, s);
|
||||||
|
} else if (auto pv = std::get_if<int>(&dims); pv) {
|
||||||
|
return tensordot(a, b, *pv, s);
|
||||||
|
} else {
|
||||||
|
return tensordot(
|
||||||
|
a, b, std::get<std::vector<std::vector<int>>>(dims), s);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"a"_a,
|
||||||
|
"b"_a,
|
||||||
|
py::pos_only(),
|
||||||
|
"dims"_a = 2,
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
"");
|
||||||
}
|
}
|
||||||
|
@ -15,6 +15,8 @@ namespace py = pybind11;
|
|||||||
using namespace mlx::core;
|
using namespace mlx::core;
|
||||||
|
|
||||||
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
|
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
|
||||||
|
using IntOrIntVec =
|
||||||
|
std::variant<std::monostate, int, std::vector<std::vector<int>>>;
|
||||||
using ScalarOrArray = std::
|
using ScalarOrArray = std::
|
||||||
variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>;
|
variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>;
|
||||||
static constexpr std::monostate none{};
|
static constexpr std::monostate none{};
|
||||||
|
@ -1547,6 +1547,22 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
expected_3 = np.repeat(data_3, 2, axis=0)
|
expected_3 = np.repeat(data_3, 2, axis=0)
|
||||||
self.assertEqualArray(repeat_3, mx.array(expected_3))
|
self.assertEqualArray(repeat_3, mx.array(expected_3))
|
||||||
|
|
||||||
|
def test_tensordot(self):
|
||||||
|
x = mx.arange(60.0).reshape(3, 4, 5)
|
||||||
|
y = mx.arange(24.0).reshape(4, 3, 2)
|
||||||
|
z = mx.tensordot(x, y, dims=([1, 0], [0, 1]))
|
||||||
|
self.assertEqualArray(z, mx.array(np.tensordot(x, y, axes=([1, 0], [0, 1]))))
|
||||||
|
x = mx.random.normal((3, 4, 5))
|
||||||
|
y = mx.random.normal((4, 5, 6))
|
||||||
|
z = mx.tensordot(x, y, dims=2)
|
||||||
|
self.assertEqualArray(z, mx.array(np.tensordot(x, y, axes=2)))
|
||||||
|
x = mx.random.normal((3, 5, 4, 6))
|
||||||
|
y = mx.random.normal((6, 4, 5, 3))
|
||||||
|
z = mx.tensordot(x, y, dims=([2, 1, 3], [1, 2, 0]))
|
||||||
|
self.assertEqualArray(
|
||||||
|
z, mx.array(np.tensordot(x, y, axes=([2, 1, 3], [1, 2, 0])))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user