* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate
numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)).
Closes#285.
* nits in docs
* unify type category checking
* nits in docs
* nits in docs
* more docs nits
* fix callable type
---------
Co-authored-by: Awni Hannun <awni@apple.com>
* some small overhead improvements
* use result_type in rms_norm
* remove release force
* fix + use non-vector version
* revert compile change
* fix ops
* a little more overhead
* a little more cleanup and overhead
* fast rmsnorm
* no rms gpu
* kernel
* fix shared mem
* looped rms and donation in softmax
* Make the squaring in float32 to avoid underflow
* Fix the default StreamOrDevice for rope and rms_norm in fast
* nits
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
* refactor tree utils
* fix compile + tree code refactor
* Add an extra test
* add a few missing activations to docs
* hash structure
* Encode the full argument structure
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
* Add linear warmup to schedules for use with existing schedules
* Changed parameters for simplicity of most common case (0 initial value)
* Added ScheduleJoiner and updated documentation
* ScheduleJoiner -> join_schedules (ala optax #)
* black compliance
* Different evaluation of schedules
* nits
---------
Co-authored-by: Awni Hannun <awni@apple.com>
* shapeless compilation for some graphs
* update compile benchmark
* default compile a few activations
* buffer donation
* bugfix
* shapeless fix
* update tests to work for cpu and gpu fusion
* test kwargs
* add kwargs to compile
* Recompile when python arguments change
* no compile for tanh
* some constant tests
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
* Add a few LR schedulers
* Move parents's constructor call to the top
* Fix docstring
* refactor optimizers into two files
* add docs
* nit
* Fix Callable type annotation for python 3.8
---------
Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
* extensions start
* rope custom op
* fix build
* docs + rope benchmark
* fix test
* Add a Metal kernel for RoPE
* Fix position of traditional
* transform tests
* Move rope computation to float and fix tests
* Fix the test and a typo
* change to fast
* fix no metal build
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
* Simple kernel generation
* Remove the generate kernel from graph_utils
* fix multi-output with compile
* fuse with stopgrad
* v1 input, output capture in compile
* cleanup tree update with visitor update
* nit
* remove todo
* state for model, optional explicit init and more pure optimizer steps
* move learning rate to state
* add lr to opt state, some fixes in capture
* fix optim
* update tuple of containers as well
* fix stream for compiled output
* rng state for compile
* nit
* updates and comments
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
* Add `py.typed` to support PEP-561 (type-hinting)
This adds support for type-hinting information as laid in [PEP-561](https://peps.python.org/pep-0561/).
* add py.typed to MANIFEST.in
* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
* Added adafactor
* Added Adafactor and ran pre-commit
* modified operations
* Added docstrings
* Switched two ops to fix a bug
* added underscore for internal functions and removed the plus sign in the last return statment
* Removed parameter rms from the optimizer state because its not needed
* Added simple MNIST test for Adafactor and temporary training log
* remove test files
* nits in docs
* comment nit
---------
Co-authored-by: Awni Hannun <awni@apple.com>