Awni Hannun
7b7e2352cd
fix malloc or wait deadlock ( #1976 )
2025-03-20 16:48:43 -07:00
Awni Hannun
7aea5b1895
Allow dynamic ops per buffer based on dispatches and memory ( #1864 )
...
* Allow dynamic ops per buffer based on dispatches and memory
* add initial arch values
2025-02-13 19:18:22 -08:00
Awni Hannun
1156c84e86
Refactor common into cpu specific and truly common ( #1817 )
...
* refactor
* fix extension example
* fix no-cpu
2025-02-03 15:58:02 -08:00
Awni Hannun
40c62c1321
Use int64 stride everywhere ( #1671 )
...
* use int64 stride everywhere
* fix ext
* fix ext
* more shape + cleanup
* one more
* few more
2024-12-09 11:09:02 -08:00
Awni Hannun
6931f84412
fix dispatch threads for a few kernels ( #1594 )
2024-11-18 08:35:25 -08:00
Awni Hannun
9f0d5c12fc
Fully wrap the command encoder ( #1572 )
...
* fully wrap the command encoder
* use consistent style + fix extensions
2024-11-08 11:50:21 -08:00
Jagrit Digani
960e3f0f05
Gemm update ( #1518 )
2024-10-30 19:30:28 -07:00
Awni Hannun
c26208f67d
Remove Hazard tracking with Fences ( #1509 )
...
* remove hazard tracking
* with fence map
* no hazard tracking with fences
* nits
* fix fence retain
* cleanup
* fix quantized rebase
2024-10-21 19:33:32 -07:00
Awni Hannun
e4534dac17
Conv grad with groups + bugfix ( #1449 )
...
* fix bug in flipped conv with groups, start of grad for groups
* fix
* fix
* fix + test
2024-10-06 07:08:53 -07:00
Angelos Katharopoulos
d878015228
Fix normalization check_input ( #1452 )
2024-10-03 13:26:56 -07:00
Awni Hannun
e7e59c6f05
Fix copying scalars by adding fill_gpu ( #1402 )
...
* fix copying scalars by adding fill_gpu
* Another copy scalar changed to fill
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com >
2024-09-09 15:54:08 -07:00
Awni Hannun
30bbea2f08
Add gemv masked to JIT plus some fixes ( #1310 )
...
* add gemv masked to JIT plus some fixes
* some cleanup
* add utils
* fix
* fix 2
* more cleaning
* fix
* remove unused mps matmul support
* one more nit
* revert
2024-08-07 13:38:07 -07:00
Jagrit Digani
2d6cd47713
Masked gemv ( #1211 )
2024-06-14 09:52:26 -07:00
Jagrit Digani
9f0df51f8d
Fix matvec vector stride bug ( #1168 )
2024-05-29 12:18:28 -07:00
Awni Hannun
7e26fd8032
Option to JIT steel gemm / conv ( #1139 )
2024-05-23 18:07:34 -07:00
Jagrit Digani
eab2685c67
Float mask update ( #1152 )
...
* Float mask update
* Update CPU impl
2024-05-23 17:20:44 -07:00
Awni Hannun
d568c7ee36
Rename block sparse ( #1149 )
...
* block_sparse_mm to gather_mm
* rename
* nit
* nit
2024-05-22 07:48:34 -07:00
Jagrit Digani
358e1fd6ab
Fused GEMM ( #1123 )
...
* Basic gemm working
* Update addmm
* Clear out steel_gemm and steel_addmm kernels
* Fuse and clear out gather gemm
* Update objc releases
2024-05-15 10:30:41 -07:00
Awni Hannun
06375e6605
Split encoders in non-concurrent context with a max ops per encoder ( #1085 )
...
* split encoders
* fix race
2024-05-09 16:21:02 -07:00
Jagrit Digani
f390957685
Block sparse mm ( #1058 )
2024-05-02 14:03:58 -07:00
Rifur13
c4a471c99d
Add groups to Conv1d ( #948 )
...
* Add conv1d grouped convs on CPU
* Add GPU support
* Parallelize inside metal kernel
* clenaup
* Update mlx/ops.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com >
* New unfold kernel + remove unused code
* Remove copy and refactor
* Update vjp and reuse steel gemm
* Fixed groups on cpu
* Fix metal validation
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com >
2024-04-27 06:24:57 -07:00
Jagrit Digani
85c8a91a27
Fix mask broadcasting bug and add relevant test ( #1003 )
2024-04-17 17:33:48 -07:00
Jagrit Digani
b18468bf81
Masked mm ( #978 )
...
* Add block masked matmul op and primitive
2024-04-16 14:45:39 -07:00
Nripesh Niketan
ffff671273
Update pre-commit hooks ( #984 )
2024-04-11 07:27:53 -07:00
Awni Hannun
12d4507ee3
Explicit barriers with concurrent dispatch ( #977 )
2024-04-10 21:45:31 -07:00
Daniel Strobusch
479051ce1c
add numeric type hierarchy and issubdtype as well as a set_dtype meth… ( #427 )
...
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate
numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42) ).
Closes #285 .
* nits in docs
* unify type category checking
* nits in docs
* nits in docs
* more docs nits
* fix callable type
---------
Co-authored-by: Awni Hannun <awni@apple.com >
2024-03-25 12:32:59 -07:00
Jagrit Digani
cec8661113
Add a SliceUpdate op and primitive ( #850 )
...
* Enable copy to work with int64 strides
* Fix uniform buffer indices or copy kernel arguments
* Update utils.h
* Remove manual unrolling of elem to loc loop
* GPU copy updated to handle negative strides
* Add slice update primitive
2024-03-20 10:39:25 -07:00
Jagrit Digani
5ad133f8bb
No copy gems ( #801 )
...
* Enable collapsing batch dims in gemm
* Update gemm to only make copies when neither of the last 2 axes are contiguous
* Update addmm to support gemv shapes
* Update addmm to support irregular batch strides
* Update tests
2024-03-12 13:13:41 -07:00
Awni Hannun
f512b905c7
Minimum xcode / sdk ( #800 )
...
* minimum xcode /sdk
* try multiple xcode versions in CI
* update python
* metal validation for python tests
2024-03-07 08:19:43 -08:00
Jagrit Digani
776c3d226d
Convolution update ( #651 )
...
* Init steel conv and update Conv primitive
* Update slow CPU implementation to support flipping and input dilation winograd conv routing
Co-authored-by: Awni Hannun <awni@apple.com >
2024-02-28 20:11:16 -08:00
Angelos Katharopoulos
0de5988f92
Custom VJP and checkpointing ( #541 )
...
* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
2024-01-30 16:04:45 -08:00
Jagrit Digani
78102a47ad
Update GEMM ( #424 )
...
* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/`
* Update gemm elements for better performance
* Add split-K specialization for gemm
* Add `addmm` primitive, op and bindings for fused matmul and bias addition
* Update tests and benchmarks as needed
2024-01-17 12:42:39 -08:00
Josh Soref
44c1ce5e6a
Spelling ( #342 )
...
* spelling: accumulates
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: across
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: additional
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: against
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: among
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: array
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: at least
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: available
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: axes
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: basically
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: bfloat
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: bounds
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: broadcast
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: buffer
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: class
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: coefficients
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: collision
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: combinations
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: committing
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: computation
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: consider
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: constructing
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: conversions
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: correctly
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: corresponding
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: declaration
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: default
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: dependency
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: destination
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: destructor
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: dimensions
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: divided
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: element-wise
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: elements
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: endianness
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: equivalent
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: explicitly
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: github
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: indices
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: irregularly
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: memory
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: metallib
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: negative
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: notable
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: optional
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: otherwise
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: overridden
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: partially
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: partition
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: perform
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: perturbations
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: positively
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: primitive
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: repeat
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: repeats
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: respect
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: respectively
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: result
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: rounding
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: separate
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: skipping
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: structure
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: the
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: transpose
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: unnecessary
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: unneeded
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: unsupported
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
---------
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
2024-01-01 21:08:17 -08:00
Jagrit Digani
d518b3b6a5
Fix gemv broadcasting bug ( #6 )
...
* Fix broadcasting bug in gemv
* Add relevant tests in test_blas.py
2023-12-05 14:15:43 -08:00
Awni Hannun
46a39e5b1f
copyright + ack
2023-11-30 11:12:53 -08:00
Awni Hannun
8ca7f9e8e9
awni's commit files
2023-11-29 10:30:41 -08:00