* fast rmsnorm
* no rms gpu
* kernel
* fix shared mem
* looped rms and donation in softmax
* Make the squaring in float32 to avoid underflow
* Fix the default StreamOrDevice for rope and rms_norm in fast
* nits
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
* mostly builds
* most tests pass
* fix circle build
* add back buffer protocol
* includes
* fix for py38
* limit to cpu device
* include
* fix stubs
* move signatures for docs
* stubgen + docs fix
* doc for compiled function, comments
* Enable collapsing batch dims in gemm
* Update gemm to only make copies when neither of the last 2 axes are contiguous
* Update addmm to support gemv shapes
* Update addmm to support irregular batch strides
* Update tests
* Fast Inference SDPA op
Implements metal shaders for:
o = mx.fast_inference_sdpa(queries, keys, values, scale, mask)
Supports fp16, fp32 dtypes; assumes d_k = 128.
Generic op support / prompt encoding supported via mlx primitives.
Metal implementation is for the inference use case only.
Majority of performance benefits appears to results from GQA & reduced
bandwidth requirements; there is approximate performance parity for the
MHA use case (from some measurements on M3 Max).
* Flush shared memory to zero before unprotected reads for (scores @ values)
* Move to fast:: namespace, address reviewer comments
... also attempt to revert formatter auto-change for files not relevant
to this change
* Shared memory flush to top of kernel
* Resolve compiler warnings
* Update python/src/fast.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/fast.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/fast.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/fast.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update docstring per PR feedback
* Softmax in higher precision, ...
* route to fallback for more use cases - batch size > 1, head_dim other
than 128, etc.
* Address linux build failure
* Address other reviewer comments
* Remove extraneous eval_cpu function per review
---------
Co-authored-by: Atila Orhon <64497909+atiorh@users.noreply.github.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: atila <atiorh@icloud.com>
* refactor tree utils
* fix compile + tree code refactor
* Add an extra test
* add a few missing activations to docs
* hash structure
* Encode the full argument structure
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
* Add linear warmup to schedules for use with existing schedules
* Changed parameters for simplicity of most common case (0 initial value)
* Added ScheduleJoiner and updated documentation
* ScheduleJoiner -> join_schedules (ala optax #)
* black compliance
* Different evaluation of schedules
* nits
---------
Co-authored-by: Awni Hannun <awni@apple.com>
* Fix case for step=inf in arange and add inf check for start/stop
* Add test cases for arange
* Update ops.cpp to include climits header
* Fix arange
* Fix formatting
* Refactor
* Add missing include
* shapeless compilation for some graphs
* update compile benchmark
* default compile a few activations
* buffer donation
* bugfix
* shapeless fix
* update tests to work for cpu and gpu fusion
* test kwargs
* add kwargs to compile
* Recompile when python arguments change
* no compile for tanh
* some constant tests
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
* Add a few LR schedulers
* Move parents's constructor call to the top
* Fix docstring
* refactor optimizers into two files
* add docs
* nit
* Fix Callable type annotation for python 3.8
---------
Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
* extensions start
* rope custom op
* fix build
* docs + rope benchmark
* fix test
* Add a Metal kernel for RoPE
* Fix position of traditional
* transform tests
* Move rope computation to float and fix tests
* Fix the test and a typo
* change to fast
* fix no metal build
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
* Simple kernel generation
* Remove the generate kernel from graph_utils
* fix multi-output with compile
* fuse with stopgrad
* v1 input, output capture in compile
* cleanup tree update with visitor update
* nit
* remove todo
* state for model, optional explicit init and more pure optimizer steps
* move learning rate to state
* add lr to opt state, some fixes in capture
* fix optim
* update tuple of containers as well
* fix stream for compiled output
* rng state for compile
* nit
* updates and comments
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
* CI update
* Skip large binary test for now
* Upgrade pip
* Add proper env variable skipping
* Update the CI
* Fix workflow name
* Set the low memory flag for the tests
* Change build process
* Add pip upgrade
* Use a venv
* Add a missing env activate
* Add setuptools
* Add twine upload back
* Re-enable automatic release builds