Custom primitive + RoPE fat op (#676)

* extensions start

* rope custom op

* fix build

* docs + rope benchmark

* fix test

* Add a Metal kernel for RoPE

* Fix position of traditional

* transform tests

* Move rope computation to float and fix tests

* Fix the test and a typo

* change to fast

* fix no metal build

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-02-14 14:04:25 -08:00
committed by GitHub
parent 1a48713d32
commit ccf1645995
18 changed files with 624 additions and 70 deletions

View File

@@ -17,6 +17,7 @@ void init_random(py::module_&);
void init_fft(py::module_&);
void init_linalg(py::module_&);
void init_constants(py::module_&);
void init_extensions(py::module_&);
PYBIND11_MODULE(core, m) {
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
@@ -33,5 +34,6 @@ PYBIND11_MODULE(core, m) {
init_fft(m);
init_linalg(m);
init_constants(m);
init_extensions(m);
m.attr("__version__") = TOSTRING(_VERSION_);
}