Angelos Katharopoulos
62f297b51d
Sdpa fix ( #1558 )
2024-11-02 21:25:46 -07:00
Awni Hannun
57c6aa7188
fix multi output leak ( #1548 )
2024-10-31 09:32:01 -07:00
Awni Hannun
4f72c66911
improvements to scatter / gather ( #1541 )
2024-10-30 19:30:54 -07:00
Alex Barron
048fabdabd
Fix vmap constant output size ( #1524 )
...
* use inputs to determine output size
* remove noop vmap tests
2024-10-30 16:16:53 -07:00
Awni Hannun
d2ff04a4f2
fix format ( #1539 )
2024-10-28 18:29:14 -07:00
Awni Hannun
0eb56d5be0
Wired ( #1510 )
...
* expose residency sets as wire/unwire
* returns wired size
* fix
* runtime support check
* fix os check
* fix test
* fix no metal build
* docs
* nit
* nits in docs
* nits
2024-10-25 09:35:33 -07:00
Venkata Naga Aditya Datta Chivukula
430ffef58a
[Feature] Added Sparse Initialization ( #1498 )
...
Co-authored-by: Saanidhyavats <saanidhyavats@gmail.com>
2024-10-24 12:31:24 -07:00
Alex Barron
3d17077187
Add mx.array.__format__ ( #1521 )
...
* add __format__
* actually test something
* fix
2024-10-24 11:11:39 -07:00
Angelos Katharopoulos
c9b41d460f
Working 64-bit scans ( #1506 )
2024-10-24 11:05:46 -07:00
Kashif Rasul
3ddc07e936
Eigenvalues and eigenvectors ( #1334 )
...
* initial eigvalsh
* add compute_vectors
* add compute_vectors_
* return a pair
* add eigh to return only eigenvectors
* fixed typo
* merge merge Eighvalsh and Eigh into a single primitive
* use the same primate with the flag
* fix primatives
* use MULTI
* fix eval_gpu
* fix decleration
* rename EighPrimitive to Eigh
* tests
* tests
* fix rebase and format
* cleanup lapack
* format
* add cblas.h
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-10-22 12:18:48 -07:00
Awni Hannun
c26208f67d
Remove Hazard tracking with Fences ( #1509 )
...
* remove hazard tracking
* with fence map
* no hazard tracking with fences
* nits
* fix fence retain
* cleanup
* fix quantized rebase
2024-10-21 19:33:32 -07:00
Alex Barron
d15fa13daf
Batched Quantized Matmul + Fast Small QMV ( #1503 )
...
* add fast qmv for small dims
* fix test
* batched cpu
* add batched template param
* refactor metal quantized.cpp
2024-10-21 16:23:17 -07:00
Awni Hannun
92d7cb71f8
Fix compile ( #1501 )
...
* fix compile
* fix space
2024-10-18 11:06:40 -07:00
Awni Hannun
3f86399922
Real and Imag ( #1490 )
...
* real and imag
* fix
* fix
2024-10-15 16:23:15 -07:00
Awni Hannun
0ab8e099e8
Fix cpu segfault ( #1488 )
...
* fix cpu segfault
* nit in tests
2024-10-14 16:17:03 -07:00
Awni Hannun
881615b072
Faster metal compiled kernels + some fixes ( #1486 )
...
* bump mac tests to use py39
* work per thread for compiled kernels
* fixe for large arrays
* fix
2024-10-14 12:45:38 -07:00
Awni Hannun
bf6ec92216
Make the GPU device more thread safe ( #1478 )
...
* gpu stream safety
* comment
* fix
2024-10-12 17:49:15 -07:00
Awni Hannun
e1c9600da3
Add mx.random.permutation
( #1471 )
...
* random permutation
* comment
2024-10-08 19:42:19 -07:00
Awni Hannun
1fa0d20a30
consistently handle all -inf in softmax ( #1470 )
2024-10-08 09:54:02 -07:00
Awni Hannun
3274c6a087
Fix array is_available race cases ( #1468 )
2024-10-07 19:13:50 -07:00
Angelos Katharopoulos
9b12093739
Add the roll op ( #1455 )
2024-10-07 17:21:42 -07:00
Awni Hannun
f374b6ca4d
Bump nanobind to 2.2 ( #1461 )
...
* bump nanobind
* extension version for tests
2024-10-07 16:52:40 -07:00
Awni Hannun
0070e1db40
Fix deep recursion with siblings ( #1462 )
...
* fix recursion with siblings
* fix
* add test
* increase tol
2024-10-07 06:15:33 -07:00
Awni Hannun
e4534dac17
Conv grad with groups + bugfix ( #1449 )
...
* fix bug in flipped conv with groups, start of grad for groups
* fix
* fix
* fix + test
2024-10-06 07:08:53 -07:00
Awni Hannun
1bdc038bf9
fix argpartition + faster {arg} sorts / partitions ( #1453 )
2024-10-03 14:21:25 -07:00
Lucas Newman
4a64d4bff1
Add support for grouped 1D convolutions to the nn API ( #1444 )
...
* Fix the weight shape for grouped convolutions from the nn API.
* Add tests.
* Pre-commit formatting.
* Add input validation.
* Use integer division instead of casting.
* docs
* nit
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-28 06:41:07 -07:00
Awni Hannun
718aea3f1d
allow take to work with integer index ( #1440 )
2024-09-26 15:58:03 -07:00
Awni Hannun
195b429d99
Put along axis + fixe for partition grad ( #1430 )
...
* put along axis, fixes for partition grad
* zeros for arg reduce
2024-09-23 10:03:38 -07:00
Nripesh Niketan
6af5ca35b2
feat: add cross_product ( #1252 )
...
* feat: add cross_product
* lint
* python binding
* refactor: Improve error message for cross_product function
* refactor: more close to numpy cross product
* refactor: improve error message for cross_product function
* finish
* fix acks
* allow old numpy
* doc
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-17 13:12:43 -07:00
Angelos Katharopoulos
914409fef9
Data parallel helper ( #1407 )
2024-09-16 18:17:21 -07:00
Awni Hannun
d6492b0163
fix clip ( #1415 )
2024-09-14 16:09:09 -07:00
Awni Hannun
8b30acd7eb
fix module attribute set, reset, set ( #1403 )
2024-09-11 16:30:42 -07:00
Awni Hannun
3ae6aabe9f
throw for certain cases of non captured inputs in compile ( #1401 )
2024-09-09 14:54:31 -07:00
Max-Heinrich Laves
efeb9c0f02
Transposed Convolution ( #1245 )
...
* initial implementation for conv_transpose
ran pre-commit
implemented conv_transpose
updated conv_general docstring
updated conv_general docstring
updated code comments
removed commented run_conv_checks
updated acknowledgments
added missing entry to ops.rst
added op to nn.layers
resolved merge conflicts
* removed ConvolutionTranspose primitive as suggested by reviewer
removed ConvolutionTranspose primitive as suggested by reviewer
* remove transpose flag, add another test
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-06 19:52:38 -07:00
Awni Hannun
ba3e913c7a
Simplifications for MLX C ( #1396 )
...
* simplifications for MLX C
* use vectors instead of map
* update examples
2024-09-06 19:16:50 -07:00
Awni Hannun
7cca1727af
Fix slice data size ( #1394 )
...
* fix slice data size and add tests
* fix contiguous flag
* simplify stride and perform copy for non-contiguous arrays
* fix cpu
* comment
2024-09-04 19:10:43 -07:00
Bhargav Yagnik
11371fe251
Test to prevent bugs like #1386 ( #1391 )
...
* updated test_array for missing ops
* formatting changes
2024-09-04 17:24:30 -07:00
Angelos Katharopoulos
969337345f
Fix reduce edge case ( #1389 )
2024-09-01 21:37:51 -07:00
Awni Hannun
0d302cd25b
Fix compiel with byte sized constants ( #1381 )
2024-08-30 17:24:35 -07:00
Aditya Dhulipala
e6b223df5f
Pinv ( #875 )
2024-08-27 23:06:12 -07:00
Angelos Katharopoulos
cdb59faea6
Adds send/recv ops in distributed ( #1366 )
2024-08-26 23:01:37 -07:00
Alex Barron
1d94ac3f90
Add optional headers to `mx.fast.metal_kernel
` ( #1358 )
2024-08-26 21:45:45 -07:00
Alex Barron
d1183821a7
int() and float() for mx.array ( #1360 )
2024-08-25 20:41:44 -07:00
Angelos Katharopoulos
8081df79be
Fix boolean all reduce bug ( #1355 )
2024-08-24 10:09:32 -07:00
Angelos Katharopoulos
b57a52813b
Further reduction tuning ( #1349 )
...
* More reduction tuning
* Forgotten pdb
* Small column long row specialization
2024-08-23 10:35:25 -07:00
Alex Barron
da8deb2b62
fix bug with multiple attributes ( #1348 )
...
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-08-23 10:06:15 -07:00
Awni Hannun
98b6ce3460
Refactor reductions and fix scatter atomics for large sizes ( #1300 )
...
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-08-22 16:03:31 -07:00
Alex Barron
0fd2a1f4b0
Custom Metal Kernels from Python ( #1325 )
...
* start
* simple kernels working
* restructure
* inverse example working
* docs + fixes
* missing file
* fix imports
* address comments
* add docs + fix test
* Review comments + refactor to a single function
* update docs
* remove hashing
* fix contig bug in test
* back to a class
* trailing whitespace
* fix tests
* match c++ and python apis
* add link + make args kw_only
2024-08-22 13:46:29 -07:00
Awni Hannun
d40e76809f
Fix rope ( #1340 )
...
* add test
* fix rope
* fix test
2024-08-20 17:37:52 -07:00
Awni Hannun
bb1b76d9dc
RoPE with frequencies as optional input ( #1337 )
...
* start rope with freq input
* rope with frequencies
* nits
* fix bug
* fix bug + test
* cleanup
* optional base
2024-08-19 18:30:50 -07:00
Awni Hannun
ae5b5cabfd
Fix optimizer reloading from checkpoint ( #1329 )
...
* fix optimizer reloading from checkpoint
* comment
2024-08-15 07:33:23 -07:00
Alex Barron
99bb7d3a58
GPU mx.sign for complex64 ( #1326 )
2024-08-14 07:54:53 -07:00
Awni Hannun
eaaea02010
Add isfinite
( #1318 )
...
* isfinite
* remove reduce test since fix is not complete
2024-08-13 14:49:28 -07:00
Bhargav Yagnik
a098bc92e0
Fix: Preserve input dtype in Dropout layer output ( #1323 )
...
* Fix: Preserve input dtype in Dropout layer output
- Modified Dropout implementation to ensure that the output dtype matches the input dtype.
- This resolves the issue #1321
* Update test cases in test_nn.py
- Revised test cases to align with updated dropout code
- Fixed assertion method: replaced self.assertTrue with self.assertEqual for accurate comparisons in test_nn.py -> test_rope, test_alibi and test_dropout,
* updated dropout.py
2024-08-13 11:54:21 -07:00
Brian Keene
19fb69e2ed
Add memory_efficient_threshold kwarg to sdpa kernel ( #1319 )
...
Allows opt-in to memory efficient GPU shader at proscribed sequence
length. Otherwise, utilizes aggregate MLX primitives for best latency.
2024-08-12 12:57:09 -07:00
Alex Barron
32668a7317
CPU mx.linalg.cholesky_inverse and mx.linalg.tri_inv ( #1307 )
...
* add cholesky inv + tri inv
* always run tri_inv on cpu
* consistent naming
2024-08-08 15:18:02 -07:00
Angelos Katharopoulos
780c197f95
Fix test tolerance and patch bump ( #1315 )
2024-08-08 14:51:09 -07:00
Alex Barron
635ccd9e25
Add "edge" mode to mx.pad ( #1309 )
...
* Add edge padding mode
* fix pad in pooling
* string arg instead of enum
2024-08-06 11:23:10 -07:00
nicolov
8c9f0278b9
Add vmap to scatter ( #1200 )
...
* Add vmap to scatter
* updates
* vmap updates + a few more tests
* bug fix
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-08-05 20:12:27 -07:00
Awni Hannun
58d0e199e1
add bfloat conv for windograd ( #1306 )
...
* add bfloat conv for windograd
* accumulate in fp32
* accumulate in fp32
* accumulate in bf16
2024-08-05 15:51:13 -07:00
Awni Hannun
10b5835501
fix creating array from bf16 tensors in jax / torch ( #1305 )
2024-08-01 16:20:51 -07:00
Alex Barron
c52d1600f0
Fused Affine Quantize/Dequantize ops ( #1282 )
...
* Add fast affine dequantize
* add full quantize kernel
* fused kernel with scale/bias computation
* fix docstring
* fix no jit error
* fix test
* test fix
* reduce fast api to only affine_quantize
2024-07-29 15:11:38 -07:00
Atakan Tekparmak
6e06e3a904
feat: Added "tanh" option to GELU approximation ( #1268 )
2024-07-28 09:07:56 +02:00
Awni Hannun
7b456fd2c0
Array api ( #1289 )
...
* some updates for numpy 2.0 and array api
* some updates for numpy 2.0 and array api
* fix array api doc
2024-07-26 10:40:49 -07:00
Anton Belov
5029894662
[Issue #1187 ] Add nan_to_num function initial attempt ( #1247 )
...
* initial attempt, working with wrong types
* not compiling; mx.float16 and mx.bfloat16 tests added
* fix nan to num
* nit
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-07-25 09:57:37 -07:00
Awni Hannun
baf9fa5f42
Einsum ( #1269 )
...
* einsum initial
* fix comma break
* sum axis was wrong
* small cleanups
* python binding
* changed bindings to resemble numpy
* remove todo comment
* comment changes
* add count of operands/inputs
* fail fast if operands list is empty
* ignore comma if no output
* einsum path matching numpy
* getting somewhere with path
* remove print
* it passes the first test
* moved einsum tests to seperate file
* seperated einsum path
* moved einsum naive
* remove space from equation
* fast fail if no operands passed
* update tests and remove printf
* small cleanup
* some more cleanups
* removed python helper file
* ack
* utilize std for finding min in vector
* duplicate def
* remove the tuple as it was unreadable
* moved einsum_naive back to ops
* remaining isn't needed
* avoid creating another set
* cleanup
* greedy path, start of naive einsum
* more einsum
* fix some bugs
* some more fixes, tests pass
* benchmark
* some simplify
* fix einsum and test
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
* add a bunch more tests and fix a bunch more bugs
* some docs nits
---------
Co-authored-by: dc-dc-dc <dgcruz983@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-07-25 09:36:44 -07:00
Jagrit Digani
7f914365fd
Fix GPU sort for large arrays ( #1285 )
...
* Fix GPU sort for large arrays
2024-07-24 14:37:10 -07:00
Paul Paczuski
ebd7135b50
Improve stability of BCE loss calculation for input probabilities close to or exactly 0 or 1 ( #1280 )
...
* Improve stability of BCE loss calculation
* Standardize comment
* Apply formatting with black via pre-commit
* Add usage recommendation to docstring
* Update python/mlx/nn/losses.py
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-07-24 08:38:22 -07:00
fgranqvist
50eff6a10a
Implement sampling from laplace distribution. ( #1279 )
2024-07-24 15:15:37 +02:00
Alex Barron
c34a5ae7f7
Fix bfloat16 Hadamard ( #1283 )
...
* fix bfloat16 hadamard
* add scale
* review comments
---------
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-07-23 14:54:43 -07:00
Awni Hannun
e2aa6ec8ae
some fixes ( #1281 )
2024-07-23 11:49:05 -07:00
Tim Gymnich
6307d166eb
Fix overflow / underflow handling for expm1f ( #1278 )
...
* Fix overflow / underflow handling for expm1f
* update tests
2024-07-23 07:29:06 -07:00
Awni Hannun
1fba87b0df
Fix leak with multi-output primitives ( #1274 )
...
* fix leak with multi-output primitives
* hopefully an actual fix
2024-07-23 06:34:18 -07:00
Angelos Katharopoulos
5c1fa64fb0
Custom transforms ( #1246 )
2024-07-10 18:00:01 -07:00
Alex Barron
a3c287354f
Fast Hadamard Transform ( #1249 )
...
* Working hadamard for powers of 2
* working for m*2^k
* add scale and check contiguity
* add size check
* clean up
* fix test
* add grads + vmap
* gpu only
* skip on linux
* test typo
* add cpu impl
* remove gpu only tests
* fix linux build + add is_equivalent
2024-07-09 20:39:01 -07:00
Alex Barron
bdb36c9a63
add zero vjps for bitwise ops and gather w.r.t. index ( #1256 )
2024-07-07 21:34:59 -07:00
Alex Barron
2615660e62
Fix strided sort bug ( #1236 )
...
* Use output strides in sort kernel
* fix zero strides bug
2024-06-26 14:32:11 -07:00
Awni Hannun
5b0af4cdb1
fix donation condition for compilation ( #1237 )
2024-06-26 09:04:05 -07:00
Jagrit Digani
2d6cd47713
Masked gemv ( #1211 )
2024-06-14 09:52:26 -07:00
Awni Hannun
df964132fb
fix scatter + test ( #1202 )
...
* fix scatter + test
* fix test warnings
* fix metal validation
2024-06-11 14:35:12 -07:00
Alex Barron
27d70c7d9d
Feature complete Metal FFT ( #1102 )
...
* feature complete metal fft
* fix contiguity bug
* jit fft
* simplify rader/bluestein constant computation
* remove kernel/utils.h dep
* remove bf16.h dep
* format
---------
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-06 12:57:25 -07:00
Angelos Katharopoulos
0163a8e57a
Add docs for the distributed namespace ( #1184 )
2024-06-06 11:37:00 -07:00
Awni Hannun
496315fe1d
Fix scan ( #1188 )
...
* fix scan
* improve grid size
* fix cpu cummax
2024-06-05 14:21:58 -07:00
Angelos Katharopoulos
0fe6895893
Fix the hard-shrink test ( #1185 )
2024-06-04 16:22:56 -07:00
Nikhil Mehta
0b7d71fd2f
Add softmin, hardshrink, hardtanh ( #1180 )
...
---------
Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
2024-06-04 15:48:18 -07:00
Awni Hannun
83b11bc58d
Fix Metal API validation for empty concat ( #1183 )
2024-06-04 13:17:08 -07:00
Awni Hannun
ea9090bbc4
Add view op ( #1179 )
...
* add view primitive
* nit
* fix view
2024-06-04 08:05:27 -07:00
Brian Keene
1865299a30
Metal shaders for memory efficient self attention on large sequences ( #964 )
...
* Metal shaders for efficient self attention on large sequences
Updated fast attention: GEMM-ified with Steel primitives
Uses flash attention 1 for scale correction
* more compiler silencing
* Address rebase issues
* Templatize kernel instantiation, revise cpu bindings
* Safer writes to output
* Permit batch size > 1
* Numerical fixes for sdpa self attention
* Re-enable test, remove unused variable
* add benchmarking script
* Disable sdpa prior to perf tuning, and simplify tests for per-patch CI
2024-06-03 09:16:19 -07:00
K Venkat Ramnan
ab977109db
feat: Added dlpack device ( #1165 )
...
* feat: Added dlpack device
* feat: Added device_id to dlpack device
* feat: Added device_id to dlpack device
* doc: updated conversion docs
* doc: updated numpy.rst dlpack information
* doc: updated numpy.rst dlpack information
* Update docs/src/usage/numpy.rst
* Update docs/src/usage/numpy.rst
---------
Co-authored-by: Venkat Ramnan Kalyanakumar <venkatramnankalyanakumar@Venkats-MacBook-Air.local>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-05-31 12:29:01 -07:00
Awni Hannun
fd1c08137b
stable cumprod grad at 0 ( #1167 )
2024-05-31 12:28:42 -07:00
Jagrit Digani
76b6cece46
Fix multi-block sort stride management ( #1169 )
...
* Fix multi-block sort stride management
* Add seed to tests
2024-05-31 11:10:54 -07:00
Jagrit Digani
9f0df51f8d
Fix matvec vector stride bug ( #1168 )
2024-05-29 12:18:28 -07:00
Awni Hannun
e7a2a3dcd1
Fix a couple bugs ( #1161 )
...
* fix jit reduce for RMS norm
* make strides a single buffer
* better eval error message
* fix compiling with inf and bf16
* fix cpu compile with bf16
2024-05-28 15:18:18 -07:00
Awni Hannun
a87ef5bfc1
fix broadcast bug in bitwise ops ( #1157 )
2024-05-24 11:44:40 -07:00
Awni Hannun
7e26fd8032
Option to JIT steel gemm / conv ( #1139 )
2024-05-23 18:07:34 -07:00
Jagrit Digani
eab2685c67
Float mask update ( #1152 )
...
* Float mask update
* Update CPU impl
2024-05-23 17:20:44 -07:00
Angelos Katharopoulos
50dfb664db
Comms ( #1097 )
...
* Start the communications branch using MPI
* Add ops and primitives
* Add python bindings for distributed
2024-05-23 17:04:02 -07:00
Rifur13
9401507336
Add groups to 2-D convolutions ( #1129 )
...
* Added groups to 2-D convolutions. Only implemented for **some** specializations.
Also fixed 1D grouped convs with different kernel strides and added more tests.
* fix channels condition
2024-05-22 20:01:44 -07:00
Awni Hannun
eb8321d863
list based indexing ( #1150 )
2024-05-22 15:52:05 -07:00
Abe Leininger
79ef49b2c2
add mx.trace ( #1143 ) ( #1147 )
...
* working c++ trace implementation
* updated throw + added overloads
* added python binding for trace function
* pre-commit reformatting
* add trace to docs
* resolve comments
* remove to_stream call
2024-05-22 15:50:27 -07:00
Awni Hannun
d568c7ee36
Rename block sparse ( #1149 )
...
* block_sparse_mm to gather_mm
* rename
* nit
* nit
2024-05-22 07:48:34 -07:00
jlwitthuhn
7e5674d8be
Treate 'minimum' differently in cosine decay ( #1138 )
2024-05-20 08:00:48 -07:00
Awni Hannun
fb71a82ada
Fix copy bug with many dims ( #1137 )
2024-05-17 21:10:03 -07:00
Luca Arnaboldi
b3ec792380
Implemented Cholesky on CPU ( #1119 )
2024-05-17 12:31:59 -07:00
Awni Hannun
81dd33af66
allow conversion to dlpack ( #1120 )
2024-05-16 16:11:37 -07:00
Angelos Katharopoulos
e78a6518fa
Block sparse qmm ( #1124 )
2024-05-16 15:24:14 -07:00
Awni Hannun
631dfbe673
fix scatter index bug ( #1122 )
2024-05-14 15:04:58 -07:00
Max-Heinrich Laves
ff4223904d
Conv3d ( #993 )
...
* added conv3d
added conv3d
implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D
* incorporated reviewer comments
* fixed test
* reduced tensor shapes in test for conv3d
* Reviewer suggestion
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Reviewer suggestion
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Reviewer suggestion
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Reviewer suggestion
2024-05-11 06:15:02 -07:00
Alex Barron
2e158cf6d0
Add conjugate operator ( #1100 )
...
* cpu and gpu impl
* add mx.conj and array.conj()
---------
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-05-10 07:22:20 -07:00
Awni Hannun
b21242faf1
Allow unary ops to accept array like ( #1093 )
2024-05-09 09:36:02 -07:00
Rahul Yedida
cc05a281c4
Added ArcTan2 operation ( #1079 )
...
* Added ArcTan2 operation
* Cleanup, bug fixes from code review
* Minor cleanup, fixed Linux tests
2024-05-08 08:35:15 -07:00
Awni Hannun
9814a2ae12
fix conversion to array ( #1070 )
2024-05-06 16:02:49 -07:00
Awni Hannun
21623156a3
Reset peak memory ( #1074 )
...
* reset peak memory
* fix linux
* nits in docs
2024-05-03 17:12:51 -07:00
Nripesh Niketan
79c859e2e0
feat: implement clip_grad_norm
( #1043 )
...
* feat: implement `clip_grad_norm`
* pre-commit
* Add test for clip_grad_norm function in test_optimizers.py
* small fixes
* fix
* lint
* Update tree_reduce
* Update python/mlx/utils.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/mlx/utils.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/mlx/utils.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/mlx/utils.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/mlx/utils.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/mlx/utils.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Refactor clip_grad_norm function to include documentation and improve readability
* format docstring
* Add acknowlegements
* text wrap
* pre-commit
* nits in docs
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-05-03 09:07:02 -07:00
Jagrit Digani
f390957685
Block sparse mm ( #1058 )
2024-05-02 14:03:58 -07:00
Angelos Katharopoulos
17f57df797
Improvements in the quantizer and dequantization kernel ( #1061 )
2024-05-01 18:19:11 -07:00
Awni Hannun
7f7b9662ea
Fix leak for multi-output primitives which are never detached ( #1059 )
...
* fix multi output leak
* ignore arrays that will be detached
* add some comments
* stray print
2024-05-01 07:31:45 -07:00
Angelos Katharopoulos
8db7161c94
Bug fix in quantize ( #1054 )
2024-04-29 20:55:04 -07:00
Awni Hannun
09f1777896
fix slice update indexing ( #1053 )
2024-04-29 12:17:40 -07:00
Jacket
490c0c4fdc
[Fix] expand axes for dimension with integer indices in mlx_slice_update ( #1035 )
...
* Not sure if this is correct
* Format
* Edit tests
* Add negative test
* Format
* add one more test
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-29 07:57:28 -07:00
Rifur13
c4a471c99d
Add groups to Conv1d ( #948 )
...
* Add conv1d grouped convs on CPU
* Add GPU support
* Parallelize inside metal kernel
* clenaup
* Update mlx/ops.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* New unfold kernel + remove unused code
* Remove copy and refactor
* Update vjp and reuse steel gemm
* Fixed groups on cpu
* Fix metal validation
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-04-27 06:24:57 -07:00
Awni Hannun
86f495985b
Add bitwise ops ( #1037 )
...
* bitwise ops
* fix tests
2024-04-26 22:03:42 -07:00
Awni Hannun
771575d27b
Expose function to clear memory cache ( #1032 )
...
* expose function to clear memory cache
* fix linux build
* fix metal tests
2024-04-24 16:48:51 -07:00
Angelos Katharopoulos
ec8578d41a
Fix quantization of all 0s ( #1028 )
2024-04-24 00:40:42 -07:00
Aneesh Shetty
d0dbfe0b97
Adds radians and degrees ( #1011 )
2024-04-22 11:17:49 -07:00
Awni Hannun
3d405fb3b1
Add synchronize function ( #1006 )
...
* add synchronize function
* fix linux
* fix linux
* fix and fix docs
* fix test
* try synchronize in stream destroy
* synchronize works for both cpu and gpu
2024-04-22 08:25:46 -07:00
Angelos Katharopoulos
84d61d27aa
Make sure 0 is represented in the quantization ( #1016 )
2024-04-19 19:47:26 -07:00
Angelos Katharopoulos
ef5f7d1aea
Fix buffer protocol buffer size designation ( #1010 )
2024-04-19 06:06:13 -07:00
Jagrit Digani
85c8a91a27
Fix mask broadcasting bug and add relevant test ( #1003 )
2024-04-17 17:33:48 -07:00
Awni Hannun
8a0677d56d
Shared events for synchronization + async eval ( #998 )
...
* more async eval
* fix rebase
* try correct async eval
* fix async
* more tests for async eval
* use shared events for synchronization
* comment + cleanup
* with autorelease pool
* fix no metal build
* fix compile
* fix patch
* don't eval if asyn evale'd
* don't use is_evaled
* comments
* more multi stream tests
* try and cleanup use of is_evaled
* use a status flag
2024-04-17 06:16:02 -07:00
Jagrit Digani
b18468bf81
Masked mm ( #978 )
...
* Add block masked matmul op and primitive
2024-04-16 14:45:39 -07:00
Shiyu
107ba2891a
gelu tanh approx ( #989 )
...
* gelu tanh approx
* gelu tanh approx
* replace gelu approx with tanh approach
* fix comments
* fix comment
2024-04-15 19:49:00 -07:00
Awni Hannun
cd9e184529
Quantize embedding ( #994 )
...
* quantize embedding
* rename as_linear + comment
* consistency in docs
* fix test
2024-04-15 16:42:10 -07:00
Alex Barron
2e7c02d5cd
Metal FFT for powers of 2 up to 2048 ( #915 )
...
* add Metal FFT for powers of 2
* skip GPU test on linux
* fix contiguity bug
* address comments
* Update mlx/backend/metal/fft.cpp
* Update mlx/backend/metal/fft.cpp
* fix bug in synch
---------
Co-authored-by: Alex Barron <abarron22@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-11 21:40:06 -07:00
Awni Hannun
ae18326533
No copy command encoder ( #986 )
...
* no copy command encoder
* up layer norm test tolerances
2024-04-11 21:15:36 -07:00
Shiyu
061cf9a4ce
Upsample with bicubic interpolation ( #967 )
2024-04-10 15:47:22 -07:00
Awni Hannun
99abb9eff4
Async eval ( #972 )
2024-04-09 18:34:00 -07:00
Luca Arnaboldi
fffe072028
Implementation of mlx.random.multivariate_normal ( #502 ) ( #877 )
...
* Implementation of mlx.random.multivariate_normal (#502 )
* Update python/src/random.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/random.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/random.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Updated typo in docstring
* Restricted multivariate_normal to float32
* Generic mean and variance shapes
* Review edits
* Update mlx/random.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/random.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/random.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/random.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Test for ndim of mean and cov
* nits
* smaller size for test
* fix broadcasted sampling
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-09 13:50:12 -07:00
Abe Leininger
a1a31eed27
Add mx.meshgrid ( #961 )
2024-04-09 11:43:08 -07:00
Awni Hannun
42afe27e12
std and expm1 ( #973 )
...
* std and expm1
* actually add expm1
* fix linux
* fix vjp
* relax tol for linux test
* Add it to the compilable primitives
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-08 14:26:01 -07:00
Awni Hannun
76e63212ff
Enable bfloat scan ( #974 )
...
* enable bfloat scan
* fix tests
2024-04-08 12:29:19 -07:00
Awni Hannun
039da779d1
No quant reshape ( #957 )
...
* precise option on cpu
* remove print
* remove reshape in quant matmul
* no quant reshape
2024-04-04 11:52:12 -07:00
Awni Hannun
d88d2124b5
segfaut layer norm grad ( #955 )
2024-04-04 10:59:15 -07:00
Awni Hannun
e142aaf8a1
Option for precise softmax ( #953 )
...
* precise softmax
* Add an equivalency check
* Make the threadgroup memory definition fixed
* precise cpu softmax
* precise option on cpu
* remove print
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-04 08:32:35 -07:00
AmirHossein_Razlighi
0caf35f4b8
Better exceptions in case of invalid operations on mlx.core.array
( #910 ) ( #926 )
...
* Nicer exceptions for ops on non-arrays
2024-04-02 21:11:24 -07:00
Angelos Katharopoulos
3fc993f82d
Properly handle negative axes in python vmap ( #944 )
2024-04-02 18:07:23 -07:00
Angelos Katharopoulos
1a87dc5ea8
Fix compile fusion for multi-output edge cases ( #950 )
...
* Fix compile fusion for multi-output edge cases
* Add a test for multi-output compile
2024-04-02 08:42:31 -07:00
Awni Hannun
2427fa171e
Fix cpu compile ( #934 )
...
* fix one cpu bug, test for another
* format hooks
* simplify contiguity check for cpu compile
* fix
* add back donation
* comment
2024-04-01 17:37:12 -07:00
Jagrit Digani
639e06e1f3
Indexing bug fix ( #947 )
...
* Fix axes accounting
* Add tests
2024-04-01 12:18:50 -07:00
Angelos Katharopoulos
02fedbf1da
Fix array initialization from list ( #942 )
...
* Fix array initialization from list
* Change the error message in the test
2024-04-01 06:27:52 -07:00