Awni Hannun
e4534dac17
Conv grad with groups + bugfix ( #1449 )
...
* fix bug in flipped conv with groups, start of grad for groups
* fix
* fix
* fix + test
2024-10-06 07:08:53 -07:00
Awni Hannun
1bdc038bf9
fix argpartition + faster {arg} sorts / partitions ( #1453 )
2024-10-03 14:21:25 -07:00
Lucas Newman
4a64d4bff1
Add support for grouped 1D convolutions to the nn API ( #1444 )
...
* Fix the weight shape for grouped convolutions from the nn API.
* Add tests.
* Pre-commit formatting.
* Add input validation.
* Use integer division instead of casting.
* docs
* nit
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-28 06:41:07 -07:00
Awni Hannun
718aea3f1d
allow take to work with integer index ( #1440 )
2024-09-26 15:58:03 -07:00
Awni Hannun
afc9c0ec1b
dtype is copy assignable ( #1436 )
2024-09-25 12:07:13 -07:00
Awni Hannun
195b429d99
Put along axis + fixe for partition grad ( #1430 )
...
* put along axis, fixes for partition grad
* zeros for arg reduce
2024-09-23 10:03:38 -07:00
Nripesh Niketan
6af5ca35b2
feat: add cross_product ( #1252 )
...
* feat: add cross_product
* lint
* python binding
* refactor: Improve error message for cross_product function
* refactor: more close to numpy cross product
* refactor: improve error message for cross_product function
* finish
* fix acks
* allow old numpy
* doc
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-17 13:12:43 -07:00
Awni Hannun
4f46e9c997
More fixes for arrays with large sizes ( #1405 )
...
* compile works for big arrays when contiguous
* style
* nits in docs
* a bunch more stuff
* update jit
* update jit
* use constant for shapes and strides and remove elem_to_loc overload
* use kernel instantiation
* docs nits
* update binary and ternary
* comments
2024-09-17 12:46:31 -07:00
Awni Hannun
c6739ba7f3
Faster RNN layers ( #1419 )
...
* faster rnn
* use admm
2024-09-17 06:04:19 -07:00
Angelos Katharopoulos
914409fef9
Data parallel helper ( #1407 )
2024-09-16 18:17:21 -07:00
Awni Hannun
d5ed4d7a71
override class function ( #1418 )
2024-09-16 13:21:04 -07:00
Nripesh Niketan
669c27140d
Chore: add pre-commit hook for cmake ( #1362 )
...
* reset and lint
* format
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-16 12:53:01 -07:00
Awni Hannun
d6492b0163
fix clip ( #1415 )
2024-09-14 16:09:09 -07:00
c0g
bd8396fad8
Fix typo in transformer docs ( #1414 )
2024-09-14 06:05:15 -07:00
Awni Hannun
8b30acd7eb
fix module attribute set, reset, set ( #1403 )
2024-09-11 16:30:42 -07:00
Awni Hannun
02efb310ca
Xcode 160 ( #1384 )
...
* xcode 16.0 with debug tests
* limit nproc for builds
* vmap bug
* assert bug
* run python tests in debug mode
* fix view, bool copies preserve bits'
* actual view fix
2024-09-10 15:15:17 -07:00
Awni Hannun
3ae6aabe9f
throw for certain cases of non captured inputs in compile ( #1401 )
2024-09-09 14:54:31 -07:00
Max-Heinrich Laves
efeb9c0f02
Transposed Convolution ( #1245 )
...
* initial implementation for conv_transpose
ran pre-commit
implemented conv_transpose
updated conv_general docstring
updated conv_general docstring
updated code comments
removed commented run_conv_checks
updated acknowledgments
added missing entry to ops.rst
added op to nn.layers
resolved merge conflicts
* removed ConvolutionTranspose primitive as suggested by reviewer
removed ConvolutionTranspose primitive as suggested by reviewer
* remove transpose flag, add another test
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-06 19:52:38 -07:00
Awni Hannun
ba3e913c7a
Simplifications for MLX C ( #1396 )
...
* simplifications for MLX C
* use vectors instead of map
* update examples
2024-09-06 19:16:50 -07:00
Awni Hannun
7cca1727af
Fix slice data size ( #1394 )
...
* fix slice data size and add tests
* fix contiguous flag
* simplify stride and perform copy for non-contiguous arrays
* fix cpu
* comment
2024-09-04 19:10:43 -07:00
Bhargav Yagnik
11371fe251
Test to prevent bugs like #1386 ( #1391 )
...
* updated test_array for missing ops
* formatting changes
2024-09-04 17:24:30 -07:00
Angelos Katharopoulos
969337345f
Fix reduce edge case ( #1389 )
2024-09-01 21:37:51 -07:00
Awni Hannun
9592766939
add std as method ( #1387 )
...
* add std as method
* add std as method
2024-09-01 19:49:16 -07:00
Awni Hannun
0d302cd25b
Fix compiel with byte sized constants ( #1381 )
2024-08-30 17:24:35 -07:00
Awni Hannun
dba2bd1105
Even Even Faster IO ( #1374 )
...
* even more faster io
* make reader pool static
* make python reader thread safe
* one more optimization
2024-08-29 16:05:40 -07:00
Awni Hannun
fcb65a3897
Even Faster I/O ( #1369 )
...
* try multithreading for faster IO
* smaller batch size
* Account for pread returning less than size
* nit
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-08-28 11:49:07 -07:00
Saanidhya
4e22a1dffe
In continuation to PR1243 to solve issue #1240 ( #1365 )
...
* Solves issue #1240
* Correction
* Update python/mlx/utils.py
* Update python/mlx/utils.py
---------
Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-08-28 11:40:41 -07:00
Awni Hannun
291cf40aca
Some fixes to typing ( #1371 )
...
* some fixes to typing
* fix module reference
* comment
2024-08-28 11:16:19 -07:00
Aditya Dhulipala
e6b223df5f
Pinv ( #875 )
2024-08-27 23:06:12 -07:00
Angelos Katharopoulos
cdb59faea6
Adds send/recv ops in distributed ( #1366 )
2024-08-26 23:01:37 -07:00
Alex Barron
1d94ac3f90
Add optional headers to `mx.fast.metal_kernel
` ( #1358 )
2024-08-26 21:45:45 -07:00
Awni Hannun
5f7d19d1f5
MPI ops in GPU stream for faster comms ( #1356 )
2024-08-26 15:12:50 -07:00
Alex Barron
d1183821a7
int() and float() for mx.array ( #1360 )
2024-08-25 20:41:44 -07:00
Angelos Katharopoulos
8081df79be
Fix boolean all reduce bug ( #1355 )
2024-08-24 10:09:32 -07:00
Alex Barron
b96e105244
Add grid_sample
example to metal_kernel
docs ( #1352 )
...
* Add `zero_outputs` and `atomic_outputs` options to `metal_kernel`
* add grid sample to docs
* zero_outputs -> init_value
* add missing header for linux
2024-08-23 18:24:16 -07:00
Awni Hannun
3b4d5484c7
Bump extension MLX version ( #1350 )
...
* Bump extension MLX version
* fix some docs nits
2024-08-23 12:38:34 -07:00
Angelos Katharopoulos
b57a52813b
Further reduction tuning ( #1349 )
...
* More reduction tuning
* Forgotten pdb
* Small column long row specialization
2024-08-23 10:35:25 -07:00
Alex Barron
da8deb2b62
fix bug with multiple attributes ( #1348 )
...
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-08-23 10:06:15 -07:00
Awni Hannun
98b6ce3460
Refactor reductions and fix scatter atomics for large sizes ( #1300 )
...
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-08-22 16:03:31 -07:00
Alex Barron
0fd2a1f4b0
Custom Metal Kernels from Python ( #1325 )
...
* start
* simple kernels working
* restructure
* inverse example working
* docs + fixes
* missing file
* fix imports
* address comments
* add docs + fix test
* Review comments + refactor to a single function
* update docs
* remove hashing
* fix contig bug in test
* back to a class
* trailing whitespace
* fix tests
* match c++ and python apis
* add link + make args kw_only
2024-08-22 13:46:29 -07:00
Awni Hannun
d40e76809f
Fix rope ( #1340 )
...
* add test
* fix rope
* fix test
2024-08-20 17:37:52 -07:00
Awni Hannun
bb1b76d9dc
RoPE with frequencies as optional input ( #1337 )
...
* start rope with freq input
* rope with frequencies
* nits
* fix bug
* fix bug + test
* cleanup
* optional base
2024-08-19 18:30:50 -07:00
Awni Hannun
ae5b5cabfd
Fix optimizer reloading from checkpoint ( #1329 )
...
* fix optimizer reloading from checkpoint
* comment
2024-08-15 07:33:23 -07:00
Alex Barron
99bb7d3a58
GPU mx.sign for complex64 ( #1326 )
2024-08-14 07:54:53 -07:00
Awni Hannun
63ae767232
fix transformer ( #1327 )
2024-08-13 16:04:26 -07:00
Awni Hannun
eaaea02010
Add isfinite
( #1318 )
...
* isfinite
* remove reduce test since fix is not complete
2024-08-13 14:49:28 -07:00
Bhargav Yagnik
a098bc92e0
Fix: Preserve input dtype in Dropout layer output ( #1323 )
...
* Fix: Preserve input dtype in Dropout layer output
- Modified Dropout implementation to ensure that the output dtype matches the input dtype.
- This resolves the issue #1321
* Update test cases in test_nn.py
- Revised test cases to align with updated dropout code
- Fixed assertion method: replaced self.assertTrue with self.assertEqual for accurate comparisons in test_nn.py -> test_rope, test_alibi and test_dropout,
* updated dropout.py
2024-08-13 11:54:21 -07:00
Brian Keene
19fb69e2ed
Add memory_efficient_threshold kwarg to sdpa kernel ( #1319 )
...
Allows opt-in to memory efficient GPU shader at proscribed sequence
length. Otherwise, utilizes aggregate MLX primitives for best latency.
2024-08-12 12:57:09 -07:00
Awni Hannun
9231617eb3
Move to nanobind v2 ( #1316 )
2024-08-08 17:17:46 -07:00
Alex Barron
32668a7317
CPU mx.linalg.cholesky_inverse and mx.linalg.tri_inv ( #1307 )
...
* add cholesky inv + tri inv
* always run tri_inv on cpu
* consistent naming
2024-08-08 15:18:02 -07:00