Commit Graph

636 Commits

Author SHA1 Message Date
Awni Hannun
d6492b0163
fix clip (#1415) 2024-09-14 16:09:09 -07:00
Awni Hannun
b3f52c9fbe
ensure io/comm streams are active before eval (#1412) 2024-09-14 06:17:36 -07:00
Angelos Katharopoulos
881f09b2e2
Allow querying the allocator for the buffer size (#1404) 2024-09-11 21:02:16 -07:00
Awni Hannun
02efb310ca
Xcode 160 (#1384)
* xcode 16.0 with debug tests

* limit nproc for builds

* vmap bug

* assert bug

* run python tests in debug mode

* fix view, bool copies preserve bits'

* actual view fix
2024-09-10 15:15:17 -07:00
Awni Hannun
e7e59c6f05
Fix copying scalars by adding fill_gpu (#1402)
* fix copying scalars by adding fill_gpu

* Another copy scalar changed to fill

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-09-09 15:54:08 -07:00
Awni Hannun
3ae6aabe9f
throw for certain cases of non captured inputs in compile (#1401) 2024-09-09 14:54:31 -07:00
xnorai
dc627dcb5e
Replace the use of result_of_t with invoke_result_t (#1397)
* Fix C++20 incompatibility

* Fix C++20 incompatibility
2024-09-06 19:52:57 -07:00
Max-Heinrich Laves
efeb9c0f02
Transposed Convolution (#1245)
* initial implementation for conv_transpose

ran pre-commit

implemented conv_transpose

updated conv_general docstring

updated conv_general docstring

updated code comments

removed commented run_conv_checks

updated acknowledgments

added missing entry to ops.rst

added op to nn.layers

resolved merge conflicts

* removed ConvolutionTranspose primitive as suggested by reviewer

removed ConvolutionTranspose primitive as suggested by reviewer

* remove transpose flag, add another test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-06 19:52:38 -07:00
Awni Hannun
ba3e913c7a
Simplifications for MLX C (#1396)
* simplifications for MLX C

* use vectors instead of map

* update examples
2024-09-06 19:16:50 -07:00
Awni Hannun
7cca1727af
Fix slice data size (#1394)
* fix slice data size and add tests

* fix contiguous flag

* simplify stride and perform copy for non-contiguous arrays

* fix cpu

* comment
2024-09-04 19:10:43 -07:00
Awni Hannun
41c603d48a
fix jit reduce (#1395) 2024-09-04 14:03:10 -07:00
Angelos Katharopoulos
969337345f
Fix reduce edge case (#1389) 2024-09-01 21:37:51 -07:00
Angelos Katharopoulos
58dca7d846
Fix copy in the sort primitive (#1383) 2024-08-31 08:32:14 -07:00
Awni Hannun
0d302cd25b
Fix compiel with byte sized constants (#1381) 2024-08-30 17:24:35 -07:00
Alex Barron
da691257ec
Fix overflow in quantize/dequantize (#1379)
* add 2d indices to prevent overflow

* use nthreads not out size
2024-08-30 13:32:41 -07:00
Awni Hannun
dba2bd1105
Even Even Faster IO (#1374)
* even more faster io

* make reader pool static

* make python reader thread safe

* one more optimization
2024-08-29 16:05:40 -07:00
Alex Barron
28be4de7c2
Fix JIT reductions (#1373) 2024-08-28 16:39:11 -07:00
Awni Hannun
a6c3b38fba
Async load (#1372)
* async load

* async load
2024-08-28 14:21:55 -07:00
Awni Hannun
fcb65a3897
Even Faster I/O (#1369)
* try multithreading for faster IO

* smaller batch size

* Account for pread returning less than size

* nit

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-08-28 11:49:07 -07:00
Jeethu Rao
bd47e1f066
Fix neon_fast_exp and add more softmax tests (#1367) 2024-08-27 23:42:42 -07:00
Aditya Dhulipala
e6b223df5f
Pinv (#875) 2024-08-27 23:06:12 -07:00
Angelos Katharopoulos
e64349bbdd
Make eval just wait if all arrays are scheduled (#1368) 2024-08-27 17:01:22 -07:00
Angelos Katharopoulos
cdb59faea6
Adds send/recv ops in distributed (#1366) 2024-08-26 23:01:37 -07:00
Alex Barron
1d94ac3f90
Add optional headers to `mx.fast.metal_kernel` (#1358) 2024-08-26 21:45:45 -07:00
Awni Hannun
5f7d19d1f5
MPI ops in GPU stream for faster comms (#1356) 2024-08-26 15:12:50 -07:00
Awni Hannun
2fdf9eb535
Fix ternary for large arrays (#1359)
* fix ternary for large arrays

* fix
2024-08-26 11:22:27 -07:00
Awni Hannun
860d3a50d7
fix extension metal library finding (#1361) 2024-08-26 09:18:50 -07:00
Angelos Katharopoulos
8081df79be
Fix boolean all reduce bug (#1355) 2024-08-24 10:09:32 -07:00
Nripesh Niketan
64bec4fad7
Chore: update pre-commit hooks (#1353)
* Chore: update pre-commit refs

* run pre-commit
2024-08-24 06:46:36 -07:00
Alex Barron
b96e105244
Add grid_sample example to metal_kernel docs (#1352)
* Add `zero_outputs` and `atomic_outputs` options to `metal_kernel`

* add grid sample to docs

* zero_outputs -> init_value

* add missing header for linux
2024-08-23 18:24:16 -07:00
Angelos Katharopoulos
b57a52813b
Further reduction tuning (#1349)
* More reduction tuning
* Forgotten pdb
* Small column long row specialization
2024-08-23 10:35:25 -07:00
Alex Barron
da8deb2b62
fix bug with multiple attributes (#1348)
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-08-23 10:06:15 -07:00
Awni Hannun
98b6ce3460
Refactor reductions and fix scatter atomics for large sizes (#1300)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-08-22 16:03:31 -07:00
Alex Barron
0fd2a1f4b0
Custom Metal Kernels from Python (#1325)
* start

* simple kernels working

* restructure

* inverse example working

* docs + fixes

* missing file

* fix imports

* address comments

* add docs + fix test

* Review comments + refactor to a single function

* update docs

* remove hashing

* fix contig bug in test

* back to a class

* trailing whitespace

* fix tests

* match c++ and python apis

* add link + make args kw_only
2024-08-22 13:46:29 -07:00
Awni Hannun
df3233454d
2d gather specialization (#1339) 2024-08-22 10:48:24 -07:00
Awni Hannun
8ae751d3da
fix io (#1343)
* fix io

* fix io

* comment
2024-08-21 13:14:46 -07:00
Awni Hannun
d40e76809f
Fix rope (#1340)
* add test

* fix rope

* fix test
2024-08-20 17:37:52 -07:00
Awni Hannun
bb1b76d9dc
RoPE with frequencies as optional input (#1337)
* start rope with freq input

* rope with frequencies

* nits

* fix bug

* fix bug + test

* cleanup

* optional base
2024-08-19 18:30:50 -07:00
Angelos Katharopoulos
9d26441224
Fix contiguity check (#1336)
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-08-19 16:05:06 -07:00
Awni Hannun
f12f24a77c
fix compiling with space in paths (#1332) 2024-08-15 16:39:24 -07:00
Awni Hannun
d0630ffe8c
Read arrays from files faster (#1330)
* read faster

* faster write as well

* set default permission for linux

* comment
2024-08-14 20:09:56 -07:00
Alex Barron
99bb7d3a58
GPU mx.sign for complex64 (#1326) 2024-08-14 07:54:53 -07:00
Awni Hannun
eaaea02010
Add isfinite (#1318)
* isfinite

* remove reduce test since fix is not complete
2024-08-13 14:49:28 -07:00
Brian Keene
19fb69e2ed
Add memory_efficient_threshold kwarg to sdpa kernel (#1319)
Allows opt-in to memory efficient GPU shader at proscribed sequence
length.  Otherwise, utilizes aggregate MLX primitives for best latency.
2024-08-12 12:57:09 -07:00
Alex Barron
32668a7317
CPU mx.linalg.cholesky_inverse and mx.linalg.tri_inv (#1307)
* add cholesky inv + tri inv

* always run tri_inv on cpu

* consistent naming
2024-08-08 15:18:02 -07:00
Angelos Katharopoulos
eb8819e91e
Revert variance to be numerically stable (#1314) 2024-08-08 13:35:02 -07:00
Awni Hannun
30bbea2f08
Add gemv masked to JIT plus some fixes (#1310)
* add gemv masked to JIT plus some fixes

* some cleanup

* add utils

* fix

* fix 2

* more cleaning

* fix

* remove unused mps matmul support

* one more nit

* revert
2024-08-07 13:38:07 -07:00
Alex Barron
635ccd9e25
Add "edge" mode to mx.pad (#1309)
* Add edge padding mode

* fix pad in pooling

* string arg instead of enum
2024-08-06 11:23:10 -07:00
nicolov
8c9f0278b9
Add vmap to scatter (#1200)
* Add vmap to scatter

* updates

* vmap updates + a few more tests

* bug fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-08-05 20:12:27 -07:00
Awni Hannun
58d0e199e1
add bfloat conv for windograd (#1306)
* add bfloat conv for windograd

* accumulate in fp32

* accumulate in fp32

* accumulate in bf16
2024-08-05 15:51:13 -07:00
Awni Hannun
43ffdab172
fix rope and random (#1301)
* fix rope and random

* comment
2024-07-31 16:18:25 -07:00
Awni Hannun
40b6d67333
Fixes for large arrays with a few ops (#1299)
* fixes for large arrays with a few ops

* fix bug

* fix all of copy
2024-07-30 17:18:39 -07:00
Alex Barron
c52d1600f0
Fused Affine Quantize/Dequantize ops (#1282)
* Add fast affine dequantize

* add full quantize kernel

* fused kernel with scale/bias computation

* fix docstring

* fix no jit error

* fix test

* test fix

* reduce fast api to only affine_quantize
2024-07-29 15:11:38 -07:00
Awni Hannun
7b456fd2c0
Array api (#1289)
* some updates for numpy 2.0 and array api

* some updates for numpy 2.0 and array api

* fix array api doc
2024-07-26 10:40:49 -07:00
Anton Belov
5029894662
[Issue #1187] Add nan_to_num function initial attempt (#1247)
* initial attempt, working with wrong types

* not compiling; mx.float16 and mx.bfloat16 tests added

* fix nan to num

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-07-25 09:57:37 -07:00
Awni Hannun
baf9fa5f42
Einsum (#1269)
* einsum initial

* fix comma break

* sum axis was wrong

* small cleanups

* python binding

* changed bindings to resemble numpy

* remove todo comment

* comment changes

* add count of operands/inputs

* fail fast if operands list is empty

* ignore comma if no output

* einsum path matching numpy

* getting somewhere with path

* remove print

* it passes the first test

* moved einsum tests to seperate file

* seperated einsum path

* moved einsum naive

* remove space from equation

* fast fail if no operands passed

* update tests and remove printf

* small cleanup

* some more cleanups

* removed python helper file

* ack

* utilize std for finding min in vector

* duplicate def

* remove the tuple as it was unreadable

* moved einsum_naive back to ops

* remaining isn't needed

* avoid creating another set

* cleanup

* greedy path, start of naive einsum

* more einsum

* fix some bugs

* some more fixes, tests pass

* benchmark

* some simplify

* fix einsum and test

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>

* add a bunch more tests and fix a bunch more bugs

* some docs nits

---------

Co-authored-by: dc-dc-dc <dgcruz983@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-07-25 09:36:44 -07:00
Jagrit Digani
7f914365fd
Fix GPU sort for large arrays (#1285)
* Fix GPU sort for large arrays
2024-07-24 14:37:10 -07:00
fgranqvist
50eff6a10a
Implement sampling from laplace distribution. (#1279) 2024-07-24 15:15:37 +02:00
Alex Barron
c34a5ae7f7
Fix bfloat16 Hadamard (#1283)
* fix bfloat16 hadamard

* add scale

* review comments

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-07-23 14:54:43 -07:00
Awni Hannun
e2aa6ec8ae
some fixes (#1281) 2024-07-23 11:49:05 -07:00
Tim Gymnich
6307d166eb
Fix overflow / underflow handling for expm1f (#1278)
* Fix overflow / underflow handling for expm1f

* update tests
2024-07-23 07:29:06 -07:00
Awni Hannun
1fba87b0df
Fix leak with multi-output primitives (#1274)
* fix leak with multi-output primitives

* hopefully an actual fix
2024-07-23 06:34:18 -07:00
Awni Hannun
df124e018a
fix gguf (#1273)
* fix gguf

* comment
2024-07-18 07:35:35 -07:00
Cheng
2f83d6e4b7
Do not release buffers on exit (#1142) 2024-07-15 15:12:24 -07:00
Feng Shijie
987785d8d7
Fix typo and missing header (#1266) 2024-07-15 08:20:24 -07:00
Angelos Katharopoulos
5c1fa64fb0
Custom transforms (#1246) 2024-07-10 18:00:01 -07:00
Alex Barron
a3c287354f
Fast Hadamard Transform (#1249)
* Working hadamard for powers of 2

* working for m*2^k

* add scale and check contiguity

* add size check

* clean up

* fix test

* add grads + vmap

* gpu only

* skip on linux

* test typo

* add cpu impl

* remove gpu only tests

* fix linux build + add is_equivalent
2024-07-09 20:39:01 -07:00
Angelos Katharopoulos
03cf033f82
Fix reshape copy bug (#1253) 2024-07-07 21:37:00 -07:00
Alex Barron
bdb36c9a63
add zero vjps for bitwise ops and gather w.r.t. index (#1256) 2024-07-07 21:34:59 -07:00
Awni Hannun
20bb301195
CPU binary reduction + Nits (#1242)
* very minor nits

* reduce binary

* fix test
2024-06-28 13:50:42 -07:00
Angelos Katharopoulos
b05bcfd27f
Fixes segfault when compiling checkpointed functions (#1235) 2024-06-26 16:14:45 -07:00
Alex Barron
2615660e62
Fix strided sort bug (#1236)
* Use output strides in sort kernel

* fix zero strides bug
2024-06-26 14:32:11 -07:00
Awni Hannun
5b0af4cdb1
fix donation condition for compilation (#1237) 2024-06-26 09:04:05 -07:00
Jagrit Digani
8c2e15e6c8
Accelerate import updates for iOS (#1227)
* Update veclib and bnns includes to #include <Accelerate/Accelerate.h> for compatibility with ios

* Mark float literals in softmax.cpp to be float16_t for errors in ios

* Add arm neon vector operation guards

* Redirect to common backend for consistency
2024-06-26 09:01:50 -07:00
Awni Hannun
56c8a33439
Get metal version from xcode (#1228)
* get metal version from xcode

* typo

* fix
2024-06-26 07:02:11 -07:00
Jagrit Digani
2d6cd47713
Masked gemv (#1211) 2024-06-14 09:52:26 -07:00
Awni Hannun
fe3167d7ea
smaller CPU binary (#1203)
* smaller CPU binary

* fix no cpu build
2024-06-14 09:46:55 -07:00
Awni Hannun
31e134be35
Build for macOS 15 (#1208)
* Build for macos 15

* metal32 as well

* comment

---------

Co-authored-by: Awni Hannun <Awni Hannun>
2024-06-13 13:31:44 -07:00
Fangjun Kuang
f20e97b092
minor fixes (#1194)
* minor fixes

* fix build errors
2024-06-12 22:06:49 -07:00
Alex Barron
934683088e
Refactor JIT for unary/binary/ternary ops (#1206)
* refactor unary/binary/ternary ops

* get_primitive_string util

---------
2024-06-12 14:22:12 -07:00
Awni Hannun
de2b9e7d0a
Fix kernel deps to reduce build times (#1205) 2024-06-12 11:17:39 -07:00
Alex Barron
dd7d8e5e29
Add Quantized Ops to the JIT (#1204)
* JIT for quantized ops

* remove unused imports

* address comments

* fix imports

* second attempt to fix imports

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-12 09:47:12 -07:00
Awni Hannun
df964132fb
fix scatter + test (#1202)
* fix scatter + test

* fix test warnings

* fix metal validation
2024-06-11 14:35:12 -07:00
Alex Barron
27d70c7d9d
Feature complete Metal FFT (#1102)
* feature complete metal fft

* fix contiguity bug

* jit fft

* simplify rader/bluestein constant computation

* remove kernel/utils.h dep

* remove bf16.h dep

* format

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-06 12:57:25 -07:00
nicolov
0e585b4409
Add docstring for scatter (#1189)
* Add docstring for scatter

* docs nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-06-06 11:51:25 -07:00
Angelos Katharopoulos
0163a8e57a
Add docs for the distributed namespace (#1184) 2024-06-06 11:37:00 -07:00
Awni Hannun
578842954c
fix jit scan when output doesn't have primitive (#1190) 2024-06-06 07:24:58 -07:00
Awni Hannun
496315fe1d
Fix scan (#1188)
* fix scan

* improve grid size

* fix cpu cummax
2024-06-05 14:21:58 -07:00
Awni Hannun
83b11bc58d
Fix Metal API validation for empty concat (#1183) 2024-06-04 13:17:08 -07:00
Alex Barron
375a8bbdcc
Add some internal GPU apis (#1177)
* Add unary/binary/ternay/slice/concat internal GPU ops

* add pad internal op

* formatting + no_cpu fix
2024-06-04 09:24:26 -07:00
Awni Hannun
ea9090bbc4
Add view op (#1179)
* add view primitive

* nit

* fix view
2024-06-04 08:05:27 -07:00
Angelos Katharopoulos
3de8ce3f3c
In place all-reduce and forgiving init (#1178) 2024-06-03 16:47:47 -07:00
Alex Barron
4d485fca24
Add defines include (#1176)
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-03 09:50:10 -07:00
Brian Keene
1865299a30
Metal shaders for memory efficient self attention on large sequences (#964)
* Metal shaders for efficient self attention on large sequences

Updated fast attention: GEMM-ified with Steel primitives
Uses flash attention 1 for scale correction

* more compiler silencing

* Address rebase issues

* Templatize kernel instantiation, revise cpu bindings

* Safer writes to output

* Permit batch size > 1

* Numerical fixes for sdpa self attention

* Re-enable test, remove unused variable

* add benchmarking script

* Disable sdpa prior to perf tuning, and simplify tests for per-patch CI
2024-06-03 09:16:19 -07:00
Awni Hannun
fd1c08137b
stable cumprod grad at 0 (#1167) 2024-05-31 12:28:42 -07:00
Jagrit Digani
76b6cece46
Fix multi-block sort stride management (#1169)
* Fix multi-block sort stride management

* Add seed to tests
2024-05-31 11:10:54 -07:00
Jagrit Digani
9f0df51f8d
Fix matvec vector stride bug (#1168) 2024-05-29 12:18:28 -07:00
Awni Hannun
e7a2a3dcd1
Fix a couple bugs (#1161)
* fix jit reduce for RMS norm

* make strides a single buffer

* better eval error message

* fix compiling with inf and bf16

* fix cpu compile with bf16
2024-05-28 15:18:18 -07:00
Awni Hannun
a87ef5bfc1
fix broadcast bug in bitwise ops (#1157) 2024-05-24 11:44:40 -07:00
Awni Hannun
7e26fd8032
Option to JIT steel gemm / conv (#1139) 2024-05-23 18:07:34 -07:00
Jagrit Digani
eab2685c67
Float mask update (#1152)
* Float mask update

* Update CPU impl
2024-05-23 17:20:44 -07:00
Angelos Katharopoulos
50dfb664db
Comms (#1097)
* Start the communications branch using MPI
* Add ops and primitives
* Add python bindings for distributed
2024-05-23 17:04:02 -07:00
Awni Hannun
0189ab6ab6
More jitting (#1132)
* docs + circle min size build

* jit scan, arange, softmax

* add sort

* jit reductions

* remove print

* fix deps

* clean includes / nits
2024-05-23 16:23:44 -07:00
Rifur13
9401507336
Add groups to 2-D convolutions (#1129)
* Added groups to 2-D convolutions. Only implemented for **some** specializations.

Also fixed 1D grouped convs with different kernel strides and added more tests.

* fix channels condition
2024-05-22 20:01:44 -07:00
Abe Leininger
79ef49b2c2
add mx.trace (#1143) (#1147)
* working c++ trace implementation

* updated throw + added overloads

* added python binding for trace function

* pre-commit reformatting

* add trace to docs

* resolve comments

* remove to_stream call
2024-05-22 15:50:27 -07:00
Awni Hannun
e110ca11e2
Fix offset bug for device buffers (#1151)
* fix bug with large offsets for buffers

* add a test

* remove test as its too big for small machine
2024-05-22 15:50:05 -07:00
Awni Hannun
226748b3e7
JIT compile option for binary minimization (#1091)
* try cpp 20 for compile

* unary, binary, ternary in jit

* nits

* fix gather/scatter

* fix rebase

* reorg compile

* add ternary to compile

* jit copy

* jit compile flag

* fix build

* use linked function for ternary

* some nits

* docs + circle min size build

* docs + circle min size build

* fix extension

* fix no cpu build

* improve includes
2024-05-22 12:57:13 -07:00
Awni Hannun
d568c7ee36
Rename block sparse (#1149)
* block_sparse_mm to gather_mm

* rename

* nit

* nit
2024-05-22 07:48:34 -07:00
Angelos Katharopoulos
da83f899bb
Improve qvm speed (#1140) 2024-05-20 09:20:44 -07:00
Awni Hannun
fb71a82ada
Fix copy bug with many dims (#1137) 2024-05-17 21:10:03 -07:00
Awni Hannun
23406c9e9e
Choose the right MLX bf16 for extensions (#1135)
* default to custom bf

* choose right bf

* fix extensions

* fix circle conf
2024-05-17 15:09:28 -07:00
Luca Arnaboldi
b3ec792380
Implemented Cholesky on CPU (#1119) 2024-05-17 12:31:59 -07:00
Angelos Katharopoulos
e78a6518fa
Block sparse qmm (#1124) 2024-05-16 15:24:14 -07:00
Awni Hannun
1873ffda01
Detect metal version and propagate correctly for JIT (#1109)
* detect metal version and propagate correctly for JIT

* remove softmax

* fix versions
2024-05-15 17:42:09 -07:00
Jagrit Digani
358e1fd6ab
Fused GEMM (#1123)
* Basic gemm working

* Update addmm

* Clear out steel_gemm and steel_addmm kernels

* Fuse and clear out gather gemm

* Update objc releases
2024-05-15 10:30:41 -07:00
Awni Hannun
863039da4c
Allow scatter type exception to be caught by checking in op (#1077)
* allow exception to be caught in main thread

* only for gpu

* more detailed scatter error
2024-05-13 17:43:53 -07:00
Awni Hannun
7178ac0111
No CPU option for binary minimization (#1105)
* no cpu build option

* docs

* fix
2024-05-13 16:08:11 -07:00
Max-Heinrich Laves
ff4223904d
Conv3d (#993)
* added conv3d

added conv3d

implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D

* incorporated reviewer comments

* fixed test

* reduced tensor shapes in test for conv3d

* Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion
2024-05-11 06:15:02 -07:00
Awni Hannun
a9f80d60f6
improve error messaging in eval (#1101) 2024-05-10 10:04:07 -07:00
Alex Barron
2e158cf6d0
Add conjugate operator (#1100)
* cpu and gpu impl

* add mx.conj and array.conj()

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-05-10 07:22:20 -07:00
Awni Hannun
8b1906abd0
Add compiler flags to disable safetensors and gguf (#1098)
* with docs

* nit
2024-05-09 17:39:44 -07:00
Awni Hannun
06375e6605
Split encoders in non-concurrent context with a max ops per encoder (#1085)
* split encoders

* fix race
2024-05-09 16:21:02 -07:00
Rahul Yedida
cc05a281c4
Added ArcTan2 operation (#1079)
* Added ArcTan2 operation

* Cleanup, bug fixes from code review

* Minor cleanup, fixed Linux tests
2024-05-08 08:35:15 -07:00
Jagrit Digani
fe96ceee66
Update block offset adjustment to be in size_t (#1087) 2024-05-08 08:10:23 -07:00
Awni Hannun
21623156a3
Reset peak memory (#1074)
* reset peak memory

* fix linux

* nits in docs
2024-05-03 17:12:51 -07:00
Awni Hannun
b00ac960b4
change initial memory limits and add memory size to device info (#1064) 2024-05-03 06:50:15 -07:00
Jagrit Digani
f390957685
Block sparse mm (#1058) 2024-05-02 14:03:58 -07:00
Angelos Katharopoulos
17f57df797
Improvements in the quantizer and dequantization kernel (#1061) 2024-05-01 18:19:11 -07:00
Awni Hannun
7f7b9662ea
Fix leak for multi-output primitives which are never detached (#1059)
* fix multi output leak

* ignore arrays that will be detached

* add some comments

* stray print
2024-05-01 07:31:45 -07:00
Awni Hannun
19bef39f5c
Add a mx.metal.device_info (#1060)
* device inof

* add variant

* fix linux

* fix doc
2024-04-30 15:47:27 -07:00
Nripesh Niketan
a30e7ed2da
feat: metal formatting and pre-commit bump (#1038)
* feat: metal formatting and pre-commit bump

* add guards

* update

* more guards

* more guards

* smakk fix

* Refactor instantiation of ternary types in ternary.metal

* fix scan.metal
2024-04-30 07:18:09 -07:00
Angelos Katharopoulos
8db7161c94
Bug fix in quantize (#1054) 2024-04-29 20:55:04 -07:00
Awni Hannun
09f1777896
fix slice update indexing (#1053) 2024-04-29 12:17:40 -07:00
Rifur13
c4a471c99d
Add groups to Conv1d (#948)
* Add conv1d grouped convs on CPU

* Add GPU support

* Parallelize inside metal kernel

* clenaup

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* New unfold kernel + remove unused code

* Remove copy and refactor

* Update vjp and reuse steel gemm

* Fixed groups on cpu

* Fix metal validation

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-04-27 06:24:57 -07:00
Awni Hannun
86f495985b
Add bitwise ops (#1037)
* bitwise ops

* fix tests
2024-04-26 22:03:42 -07:00
Awni Hannun
67d1894759
fix order device -> scheduler (#1039) 2024-04-26 13:46:41 -07:00
Awni Hannun
5bfe89bdb1
Cpp docs (#1036)
* start of C++ docs

* fix stream doc

* only include ops for now
2024-04-26 12:56:05 -07:00
Awni Hannun
771575d27b
Expose function to clear memory cache (#1032)
* expose function to clear memory cache

* fix linux build

* fix metal tests
2024-04-24 16:48:51 -07:00
Angelos Katharopoulos
20a01bbd9f
Simplifying and improving qmm (#1030) 2024-04-24 13:07:45 -07:00
Angelos Katharopoulos
ec8578d41a
Fix quantization of all 0s (#1028) 2024-04-24 00:40:42 -07:00
Aneesh Shetty
d0dbfe0b97
Adds radians and degrees (#1011) 2024-04-22 11:17:49 -07:00
Awni Hannun
3d405fb3b1
Add synchronize function (#1006)
* add synchronize function

* fix linux

* fix linux

* fix and fix docs

* fix test

* try synchronize in stream destroy

* synchronize works for both cpu and gpu
2024-04-22 08:25:46 -07:00
Angelos Katharopoulos
84d61d27aa
Make sure 0 is represented in the quantization (#1016) 2024-04-19 19:47:26 -07:00
Awni Hannun
ed83908931
fix gguf loading quants (#1014)
* fix gguf loading quants

* fix nanobind install

* actual fix
2024-04-19 12:24:07 -07:00
Jagrit Digani
85c8a91a27
Fix mask broadcasting bug and add relevant test (#1003) 2024-04-17 17:33:48 -07:00
Awni Hannun
8a0677d56d
Shared events for synchronization + async eval (#998)
* more async eval

* fix rebase

* try correct async eval

* fix async

* more tests for async eval

* use shared events for synchronization

* comment + cleanup

* with autorelease pool

* fix no metal build

* fix compile

* fix patch

* don't eval if asyn evale'd

* don't use is_evaled

* comments

* more multi stream tests

* try and cleanup use of is_evaled

* use a status flag
2024-04-17 06:16:02 -07:00
Jagrit Digani
b18468bf81
Masked mm (#978)
* Add block masked matmul op and primitive
2024-04-16 14:45:39 -07:00
Alex Barron
2e7c02d5cd
Metal FFT for powers of 2 up to 2048 (#915)
* add Metal FFT for powers of 2

* skip GPU test on linux

* fix contiguity bug

* address comments

* Update mlx/backend/metal/fft.cpp

* Update mlx/backend/metal/fft.cpp

* fix bug in synch

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-11 21:40:06 -07:00
Awni Hannun
ae18326533
No copy command encoder (#986)
* no copy command encoder

* up layer norm test tolerances
2024-04-11 21:15:36 -07:00
Angelos Katharopoulos
dce4bd74a4
Add ArrayDesc destructor to avoid possible stack overflow (#982) 2024-04-11 11:37:02 -07:00
Nripesh Niketan
ffff671273
Update pre-commit hooks (#984) 2024-04-11 07:27:53 -07:00
Awni Hannun
12d4507ee3
Explicit barriers with concurrent dispatch (#977) 2024-04-10 21:45:31 -07:00
Awni Hannun
8580d997ff
Try a stack-based DFS for eval (#980)
* rebase

* nit

* fix eval in vmap
2024-04-10 17:05:13 -07:00
Awni Hannun
99abb9eff4
Async eval (#972) 2024-04-09 18:34:00 -07:00
Luca Arnaboldi
fffe072028
Implementation of mlx.random.multivariate_normal (#502) (#877)
* Implementation of mlx.random.multivariate_normal (#502)

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Updated typo in docstring

* Restricted multivariate_normal to  float32

* Generic mean and variance shapes

* Review edits

* Update mlx/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Test for ndim of mean and cov

* nits

* smaller size for test

* fix broadcasted sampling

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-09 13:50:12 -07:00
Abe Leininger
a1a31eed27
Add mx.meshgrid (#961) 2024-04-09 11:43:08 -07:00
Awni Hannun
ae812350f9
use string (#976) 2024-04-09 11:22:00 -07:00
Awni Hannun
42afe27e12
std and expm1 (#973)
* std and expm1

* actually add expm1

* fix linux

* fix vjp

* relax tol for linux test

* Add it to the compilable primitives

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-08 14:26:01 -07:00
Awni Hannun
76e63212ff
Enable bfloat scan (#974)
* enable bfloat scan
* fix tests
2024-04-08 12:29:19 -07:00
Awni Hannun
aac2f9fb61
Improve profiling with gpu tracing (#969)
* improve profiling with gpu tracing

* fix for linux

* nit

* doc fix

* fix example
2024-04-07 21:47:43 -07:00
Awni Hannun
039da779d1
No quant reshape (#957)
* precise option on cpu

* remove print

* remove reshape in quant matmul

* no quant reshape
2024-04-04 11:52:12 -07:00
Awni Hannun
d88d2124b5
segfaut layer norm grad (#955) 2024-04-04 10:59:15 -07:00
Awni Hannun
e142aaf8a1
Option for precise softmax (#953)
* precise softmax

* Add an equivalency check

* Make the threadgroup memory definition fixed

* precise cpu softmax

* precise option on cpu

* remove print

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-04 08:32:35 -07:00
Awni Hannun
741eb28443
fix a couple bugs (#952) 2024-04-02 12:07:41 -07:00
Angelos Katharopoulos
1a87dc5ea8
Fix compile fusion for multi-output edge cases (#950)
* Fix compile fusion for multi-output edge cases

* Add a test for multi-output compile
2024-04-02 08:42:31 -07:00
Awni Hannun
2427fa171e
Fix cpu compile (#934)
* fix one cpu bug, test for another

* format hooks

* simplify contiguity check for cpu compile

* fix

* add back donation

* comment
2024-04-01 17:37:12 -07:00
Angelos Katharopoulos
110d9b149d
Layer norm grad fix donation bug (#941)
* add layer norm grad test

* Fix donation bug in layernorm vjp

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-01 06:15:50 -07:00
Angelos Katharopoulos
9cbff5ec1d
Fix typo in qmm check (#940) 2024-03-31 19:15:44 -07:00
Awni Hannun
8915901966
Donation bug (#933)
* donation

* buf

* fix bug in softmax

* comment

* remove print
2024-03-30 10:08:54 -07:00
Cheng
913b19329c
Add missing && when forwarding args (#925)
Without the && args would be copied and perfect forwarding won't work.
2024-03-29 06:48:29 -07:00
Angelos Katharopoulos
5f9ba3019f
Fix qmm_t for unaligned cases (#923) 2024-03-28 15:34:57 -07:00
Cheng
46caf0bef0
Remove unnecessary string copies (#891)
1. Use string_view instead of string when there is no need for copy.
2. Otherwise move string when possible.
2024-03-28 13:14:59 -07:00
Jack Mousseau
45f636e759
Add Metal debug option and capture functions (#707)
* Add Metal debug option and capture functions

* Add brief Metal debugger documentation

* doc nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-28 09:40:31 -07:00
Cheng
a7b404ff53
Use uintptr_t instead of size_t to store funtion id (#916)
Also does some small cleanup of the compile cache code.
2024-03-28 06:37:59 -07:00
Cheng
bab5386306
Make ops aware of rvalues: astype/as_strided/copy/full (#895)
When compositing transforms lots of temporary of arrays will be created
and passed to next primitive, and by making ops accepting args by value
we can avoid lots of copies of temporary arrays.
2024-03-27 22:35:55 -07:00
Angelos Katharopoulos
aca7584635
Fix OOB read in qmv when non-divisible by blocksize (#917) 2024-03-27 22:18:35 -07:00
Cheng
90dfa43ff1
Don't use make_unique to create shared_ptr (#902)
The code compiled because shared_ptr's constructor actually accepts
unique_ptr.
2024-03-27 06:13:29 -07:00
Awni Hannun
dc175f08d3
Fix race in multi-stream eval (#911)
* maybe fix race

* comment
2024-03-26 16:36:36 -07:00
Angelos Katharopoulos
29221fa238
Implement vjps for some primitives in the fast namespace (#883)
* Implement rope vjp in terms of rope
* RMSNormVJP primitive and kernel
* Add LayerNormVJP primitive and kernel
2024-03-26 16:35:34 -07:00
Cheng
a789685c63
Remove duplicate defines of StreamOrDevice and is_big_endian (#892) 2024-03-26 15:15:11 -07:00
Jagrit Digani
240d10699c
Implement negative padding in conv with slicing (#907)
* Implement negative padding with slicing

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni@apple.com>

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-26 14:59:19 -07:00
Jagrit Digani
925014b661
Fix multiblock sort limits (#906)
* Fix multiblock sort limits

* Fix metal validation error
2024-03-26 14:00:00 -07:00
Angelos Katharopoulos
9948eddf11
Fix nan and improve speed for qvm (#903) 2024-03-26 10:41:45 -07:00
Luca Arnaboldi
a3ee03da01
Fixing random.normal for half-precision dtype #642 (#904)
* Fixing random.normal for half-precision dtype #642

* Update python/tests/test_random.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-26 09:58:27 -07:00
Cheng
28fcd2b519
Add missing && when forwarding args (#894)
Without the && args would be copied and perfect forwarding won't work.

Also add template utils to make sure the function only forwards array
and not vector<array>.
2024-03-25 14:55:54 -07:00
Jack Mousseau
8e686764ac
Ensure shape dimensions are within supported integer range (#566) (#704)
* Ensure shape dimensions are within supported integer range (#566)

* fix build

* fix rebase bug

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-25 13:29:45 -07:00
Daniel Strobusch
479051ce1c
add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate

numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)).

Closes #285.

* nits in docs

* unify type category checking

* nits in docs

* nits in docs

* more docs nits

* fix callable type

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-25 12:32:59 -07:00
Awni Hannun
be98f4ab6b
Reduce a little overhead (#871)
* some small overhead improvements

* use result_type in rms_norm

* remove release force

* fix + use non-vector version

* revert compile change

* fix ops

* a little more overhead

* a little more cleanup and overhead
2024-03-22 17:29:36 -07:00
Angelos Katharopoulos
6ee1112f30
Fix copy donation and add partial rope (#881) 2024-03-22 17:28:26 -07:00
Cheng
9663c22fe9
Do not store iostream in shared_ptr (#872)
There is no need to store iostream in shared_ptr, doing so adds the cost
of a heap allocation.
2024-03-22 06:54:45 -07:00
Cheng
f0ae00da12
Reduce implicit copies in make_array (#874)
1. Move shapes into outputs instead of copying them.
2. Pass primitive by const ref as it is always copied into outputs, which
   removes a copy when calling make_array.
2024-03-22 06:29:16 -07:00
Angelos Katharopoulos
2225374060
Adds mx.fast.layer_norm (#870) 2024-03-21 13:55:51 -07:00
nicolov
105d236889
Add vmap for SVD and inverse (#849) 2024-03-21 13:18:27 -07:00
Awni Hannun
a54f06b16f
Fast RMS Norm (#862)
* fast rmsnorm

* no rms gpu

* kernel

* fix shared mem

* looped rms and donation in softmax

* Make the squaring in float32 to avoid underflow

* Fix the default StreamOrDevice for rope and rms_norm in fast

* nits

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-21 07:20:54 -07:00
Cheng
4650d94d98
Add missing && in eval (#864)
Without the && args would be copied and perfect forwarding won't work.

To avoid eval calling itself recursively, the vector version of eval is
changed to take by value instead, which will save a copy of array when a
rvalue is passed.
2024-03-21 06:15:48 -07:00
Jagrit Digani
a5681ebc52
Update set item (#861)
* Update mlx_set_item to handle regular slices without expanding

* Refactor ellipsis handling

* Route mlx_set_item to slice_update where possible

* Update mlx_scatter_args_slice

* Don't route to gather if no array indices
2024-03-21 02:48:13 -07:00
Cheng
e849b3424a
Do not use static constexpr in header (#863)
Doing so results in each compilation unit (.cpp file) having its own
copy of the variable, while inline constexpr makes sure there is only
one copy.
2024-03-20 21:28:05 -07:00
Jagrit Digani
b219d12a6b
Check edge case handling in row reduce med kernel (#858) 2024-03-20 11:37:58 -07:00
Jagrit Digani
cec8661113
Add a SliceUpdate op and primitive (#850)
* Enable copy to work with int64 strides
* Fix uniform buffer indices or copy kernel arguments
* Update utils.h
* Remove manual unrolling of elem to loc loop
* GPU copy updated to handle negative strides
* Add slice update primitive
2024-03-20 10:39:25 -07:00
Cheng
73a8c090e0
Pass shape and inputs by value in array's constructor (#853)
Since the shape and inputs are always saved as copy in ArrayDesc, we can
unify array's constructors to just take the arguments by value.

There are 2 cases:
1. When shape is a lvalue, it will be copied into array's constructor and
   then moved into ArrayDesc's member. So only 1 copy happens.
2. When shape is a rvalue, it will be moved into array's constructor and
   then moved into ArrayDesc's member. So no copy happens.

So having 1 constructor that takes by value is equivalent to having 2
constructors that const reference and rvalue separately.
2024-03-20 07:54:30 -07:00
Awni Hannun
9a8ee00246
Switch to nanobind (#839)
* mostly builds

* most tests pass

* fix circle build

* add back buffer protocol

* includes

* fix for py38

* limit to cpu device

* include

* fix stubs

* move signatures for docs

* stubgen + docs fix

* doc for compiled function, comments
2024-03-18 20:12:25 -07:00
Cheng
d39ed54f8e
Some C++ code are not needed (#841)
1. Anonymous namespace means internal linkage, static keyword is not needed.
2. The default constructor of std::shared_ptr initializes the pointer to
   nullptr, you don't need to explicitly set it.
2024-03-18 17:04:10 -07:00
Awni Hannun
16546c70d8
No reshape rope (#838)
* no reshape rope

* no reshape rope
2024-03-18 17:03:07 -07:00
nicolov
eaba55c9bf
Add matrix inversion primitive (#822) 2024-03-15 06:34:36 -07:00
Awni Hannun
19ec023256
vmap matmul and admm (#836) 2024-03-14 14:38:22 -07:00
Jagrit Digani
8dfc376c00
Strided reduce specialization for small reductions (#826)
* Add small column / general reduction specialization
2024-03-14 09:16:53 -07:00
Angelos Katharopoulos
1efee9db09
Add types and order in kernel name (#831) 2024-03-13 20:34:06 -07:00
Awni Hannun
43abc402d8
route to fallback (#828) 2024-03-13 19:56:04 -07:00
Angelos Katharopoulos
3f8b1668c4
Make reshape faster for row_contiguous cases (#829) 2024-03-13 16:22:03 -07:00
Angelos Katharopoulos
76c919b4ec
NumberOfElements for shapeless compile and vmap fixes (#802) 2024-03-13 10:34:14 -07:00
Angelos Katharopoulos
29d0c10ee5
Reshape improvement (#818) 2024-03-12 17:54:31 -07:00
Jagrit Digani
5ad133f8bb
No copy gems (#801)
* Enable collapsing batch dims in gemm
* Update gemm to only make copies when neither of the last 2 axes are contiguous
* Update addmm to support gemv shapes
* Update addmm to support irregular batch strides
* Update tests
2024-03-12 13:13:41 -07:00
nicolov
d0c544a868
Add SVD primitive (#809)
Add SVD op using Accelerate's LAPACK following
https://developer.apple.com/documentation/accelerate/
compressing_an_image_using_linear_algebra

Co-authored-by: Nicolo Valigi <nvaligi@apple.com>
2024-03-12 12:30:11 -07:00
Awni Hannun
8b7532b9ab
fix scatter (#821) 2024-03-12 11:42:07 -07:00
Awni Hannun
0e95b64942
Fix bug in tape order during simplify (#816)
* fix bug in tape order during simplify

* properly fix compile

* last bug
2024-03-11 17:29:05 -07:00
nicolov
0ae22b915b
Remove code duplication in reduce ops (#793)
* Remove code duplication in reduce ops

* Remove the unnecessary lambda

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-11 10:57:07 -07:00
Awni Hannun
7c441600fe
Compile stride bug (#812)
* fix compile stride bug

* revert sdpa fix

* fix cpu

* fix bug with simplifying outputs
2024-03-11 06:31:31 -07:00
Awni Hannun
a4d290adb9
Remove depth traversal (#813)
* no depth traversal

* counter outside loop
2024-03-09 20:21:32 -08:00
Jagrit Digani
ec8a4864fa
Fix SDPA kernel bug on Mac OS 13.3 SDK (#805)
* Move sdpa kernel to allocate tgp mem statically and allow macOS 13.3 SDK builds

* Style
2024-03-07 10:18:09 -08:00
Awni Hannun
f512b905c7
Minimum xcode / sdk (#800)
* minimum xcode /sdk

* try multiple xcode versions in CI

* update python

* metal validation for python tests
2024-03-07 08:19:43 -08:00
Awni Hannun
afd5274049
route to fallback for bfloat (#794) 2024-03-06 15:39:12 -08:00
Awni Hannun
1074674e32
Add a maximum graph depth (#797)
* add a maximum graph depth

* remember how to use C++
2024-03-06 15:39:00 -08:00
Angelos Katharopoulos
e39bebe13e
Fix reshaping of empty arrays (#791) 2024-03-05 23:33:22 -08:00
Angelos Katharopoulos
14b4e51a7c
Improved quantized matrix vector product (#786) 2024-03-05 17:32:19 -08:00
Awni Hannun
cbcf44a4ca
Some fixes in cache / thread safety (#777)
* some fixes in cache / thread safety

* speed up no cache case

* fix opt test

* optimizer docs

* otpimizer docs

* fix adafactor

* fix adafactor
2024-03-05 13:30:50 -08:00
Brian Keene
0787724c44
Fast Inference SDPA op (#735)
* Fast Inference SDPA op

Implements metal shaders for:

o = mx.fast_inference_sdpa(queries, keys, values, scale, mask)

Supports fp16, fp32 dtypes; assumes d_k = 128.

Generic op support / prompt encoding supported via mlx primitives.
Metal implementation is for the inference use case only.

Majority of performance benefits appears to results from GQA & reduced
bandwidth requirements; there is approximate performance parity for the
MHA use case (from some measurements on M3 Max).

* Flush shared memory to zero before unprotected reads for (scores @ values)

* Move to fast:: namespace, address reviewer comments

... also attempt to revert formatter auto-change for files not relevant
to this change

* Shared memory flush to top of kernel

* Resolve compiler warnings

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update docstring per PR feedback

* Softmax in higher precision, ...

* route to fallback for more use cases - batch size > 1, head_dim other
  than 128, etc.
* Address linux build failure
* Address other reviewer comments

* Remove extraneous eval_cpu function per review

---------

Co-authored-by: Atila Orhon <64497909+atiorh@users.noreply.github.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: atila <atiorh@icloud.com>
2024-03-04 21:06:11 -08:00
Awni Hannun
7b463ffb07
Ios compile (#784)
* try to fix build for ios

* skip cpu compile

* fix namespace

* fix namespace

* Use CMake for platform specific cpu compile

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-04 20:02:26 -08:00
Jagrit Digani
6686e61ca4
Reduce update (#783)
* Split reduction files to reduce compile times

* Add small and medium axis size specializations for row reductions

* Add non-row-reduction options for small and med kernels
2024-03-04 19:09:51 -08:00
Awni Hannun
5121f028d9
nice tensordot for mlx c (#782) 2024-03-04 09:51:02 -08:00
Awni Hannun
bc06cb9ff6
Pickle + dtype fix for numpy conversion (#763)
* pickle + dtype fix for numpy conversion

* fix getattribute on Module base

* remove unused function

* fix tests

* add topk to ops

* fix doc
2024-03-02 06:09:29 -08:00
Angelos Katharopoulos
8e281c76c3
Fix the top-k op (#768) 2024-03-01 22:08:43 -08:00
Awni Hannun
d5964a2710
bindings for memory info (#761)
* bindings for memory info

* update api

* keep cache low if requested

* fix default

* nit in ops error
2024-03-01 19:51:58 -08:00
Jagrit Digani
776c3d226d
Convolution update (#651)
* Init steel conv and update Conv primitive

* Update slow CPU implementation to support flipping and input dilation winograd conv routing

Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-28 20:11:16 -08:00
Awni Hannun
f5f18b704f
fix temporary bug (#752) 2024-02-27 17:44:39 -08:00
Awni Hannun
56ba3ec40e
fix cpu compile on older OS (#747) 2024-02-26 22:20:53 -08:00
Hinrik Snær Guðmundsson
08226ab491
added atleast *args input support (#710)
* added atleast list(array) input support

* function overloading implemented

* Refactoring

* fixed formatting

* removed pos_only
2024-02-26 11:17:59 -08:00
Awni Hannun
e6418781ab
Fix logsumexp edge case (#740)
* fix logsumexp

* fix inf constant

* also fix power grad

* fix ternary dispatch
2024-02-25 08:39:55 -08:00
Awni Hannun
ac02cf33bd
Fix some issues using MLX in C++ (#739)
* fix preamble build

* fix some issues with using MLX as a dep in C++
2024-02-24 22:20:57 -08:00
Noah Farr
d729a1991b
Fix arange with inf step (#686)
* Fix case for step=inf in arange and add inf check for start/stop

* Add test cases for arange

* Update ops.cpp to include climits header

* Fix arange

* Fix formatting

* Refactor

* Add missing include
2024-02-23 06:18:15 -08:00
Rifur13
126c9869c8
Implement the 'where' primitive for conditional selection (#664) 2024-02-22 15:10:48 -08:00
Jagrit Digani
884b4ed43b
Fix threadgroup memory in arg reduce (#723) 2024-02-21 19:42:16 -08:00
Vijay Krish
972d9a3aea
Up to 10x faster scatter. (#709)
* Faster scatter.

Add specialization for 1-d index tensors.

* Address review comments.

- Check for row contiguity of index, update tensors
  instead of checking strides.
- Add support for 1d specialization with col contiguous update
  tensor, along with a test.

* Nit1

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Nit2

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-21 11:09:30 -08:00
Awni Hannun
5798256fcf
Shapeless compilation for some graphs (#687)
* shapeless compilation for some graphs

* update compile benchmark

* default compile a few activations

* buffer donation

* bugfix

* shapeless fix

* update tests to work for cpu and gpu fusion

* test kwargs

* add kwargs to compile

* Recompile when python arguments change

* no compile for tanh

* some constant tests

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-19 21:43:54 -08:00
Hinrik Snær Guðmundsson
f883fcede0
Added support for atleast_1d, atleast_2d, atleast_3d (#694) 2024-02-19 09:40:52 -08:00
Awni Hannun
1a4f4c5ea6
Refactor CPU compile preamble (#708)
* refactor cpu preamble

* fix include order

* fix some issues'

* fixes for linux

* try to fix includes

* add back warning suppression

* more linux fixes
2024-02-19 06:12:53 -08:00
Jack Mousseau
0925af43b0
Remove unused variables (#706) 2024-02-18 12:50:10 -08:00
Awni Hannun
dc937b8ed3
CPU compile (#691)
* build and load shared object for cpu compile

* nits

* cpu compile tests pass

* cpu compile tests pass

* fix preamble for g++

* donation

* fix gpu buffer donation

* reuse prebuilt libraries

* faster contiguity conditoins

* fix test

* rid compiler warning

* fast erf

* Fix float16 for compile and add more types to cpu compile

* Remove a forgotten comment

* use cached libs

* nits

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-17 06:54:32 -08:00
Awni Hannun
c3965fc5ee
Separate fast ops and primitives (#699) 2024-02-16 19:16:39 -08:00
toji
85143fecdd
improved error msg for invalid axis(mx.split) (#685)
* improved error msg for invalid axis(`mx.split`)

* Apply suggestions from code review

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fixed formatting issue

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-15 07:25:38 -08:00
Diogo
35431a4ac8
Adds device context manager (#679) 2024-02-14 14:14:58 -08:00
Awni Hannun
ccf1645995
Custom primitive + RoPE fat op (#676)
* extensions start

* rope custom op

* fix build

* docs + rope benchmark

* fix test

* Add a Metal kernel for RoPE

* Fix position of traditional

* transform tests

* Move rope computation to float and fix tests

* Fix the test and a typo

* change to fast

* fix no metal build

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-14 14:04:25 -08:00
Jagrit Digani
1a48713d32
Update gather and scatter to not use Argument Encoder (#683)
* Replace argument encoder usage for gather and scatter

* Use constant address space for shapes and strides

* Split gather and scatter to improve compile times

* Enable the GPU tests

* Update the CI config

* Fix scatter dispatch for scalar indices

* Remove arg encoder utils

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-14 13:42:13 -08:00
Awni Hannun
1eb04aa23f
Fix empty array construction in cpp (#684) 2024-02-13 23:34:17 -08:00
Noah Farr
0c65517e91
Return empty array when repeats is 0 in mx.repeat (#681)
* Return empty array when repeats is 0

* Add test case for repeats = 0
2024-02-13 17:49:31 -08:00
Vijay Krish
2fdc2462c3
Faster gather and scatter. (#682)
Reduce unnecessary integer ops, especially since
there kernels are integer bound.

Increase number of iterations for benchmarks for
better smoothing.

Github Issue #506

Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
2024-02-13 17:47:41 -08:00
Angelos Katharopoulos
40c108766b
Quantized matmul fix (#677)
* Fix qmv for small or unaligned matrices

* Fix qmm
2024-02-12 18:54:21 -08:00
Awni Hannun
3756381358
Faster bfloat quantized mat-vec and vec-mat (#663) 2024-02-11 21:53:16 -08:00
Awni Hannun
d12573daa6
quote file name (#670) 2024-02-11 10:33:30 -08:00
Vijay Krish
06072601ce
Scatter optimization : Eliminate 64b integer divide. (#662)
Launch 2D grid to eliminate divide and mod in device code,
since 64b integer division is very expensive.

Github Issue #506

Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
2024-02-10 08:49:51 -08:00
Awni Hannun
7f3f8d8f8d
Fix the softmax fix (#661) 2024-02-09 17:02:13 -08:00
Awni Hannun
b96be943dc
bug fix (#658) 2024-02-09 16:50:45 -08:00
Abdussamet Türker
b670485185
Remainder negative numerator bug fixed (#641)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-09 16:49:14 -08:00
Diogo
b57bd0488d
Metadata support for safetensors (#639)
* metadata support for safetensors

* aliases making it alittle more readable

* addressing comments

* python binding tests
2024-02-08 19:33:15 -08:00
Awni Hannun
1b97b2958b
Compile with capture (#629)
* Simple kernel generation

* Remove the generate kernel from graph_utils

* fix multi-output with compile

* fuse with stopgrad

* v1 input, output capture in compile

* cleanup tree update with visitor update

* nit

* remove todo

* state for model, optional explicit init and more pure optimizer steps

* move learning rate to state

* add lr to opt state, some fixes in capture

* fix optim

* update tuple of containers as well

* fix stream for compiled output

* rng state for compile

* nit

* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-07 17:29:22 -08:00
Angelos Katharopoulos
28eac18571
Kernel generation (#614)
Generate reusable element-wise kernels given a computation graph.
2024-02-07 13:15:59 -08:00
Noah Farr
5fd11c347d
Add loc and scale to random.normal (#638)
* Add loc and scale to random.normal

* Add tests for loc and scale for random.normal

* Run pre-commit hooks

* Fix code review
2024-02-07 11:49:59 -08:00
Awni Hannun
146bd69470
Skip compile when transforming (#635)
* skip compile when transforming

* simplify message
2024-02-05 21:28:37 -08:00
Jagrit Digani
316ff490b3
Remove masks from BlockLoader and clear out load case for invalid thread (#634) 2024-02-05 16:00:17 -08:00
Awni Hannun
d40a04f8dc
minor fixes (#631)
* minor fixes

* var with ddof >= nelements
2024-02-05 13:27:49 -08:00
Awni Hannun
d75ae52ecd
Compile primitive (#571)
* Compiled primitive with basic binary, unary graph-level fusion
2024-02-05 06:51:22 -08:00
Avikant Srivastava
31fea3758e
feat: enhancement of the error message for mlx.core.mean (#608)
* add error message
2024-02-05 01:21:49 -08:00
Awni Hannun
e319383ef9
Faster gather (#626)
* faster gather

* update copyright
2024-02-04 17:25:44 -08:00
David Koski
ebfd3618b0
fixes for building and running on iOS (#619)
* fixes for building and running on iOS

* per suggestion just use Accelerate
2024-02-04 12:29:17 -08:00
Avikant Srivastava
11a9fd40f0
fix: handle linspace function when num is 1 (#602)
* fix: handle linspace function when num is 1

* add comment

* fix test case

* remove breakpoint
2024-02-04 11:03:49 -08:00
Awni Hannun
95b5fb8245
minor changes (#613) 2024-02-02 11:48:35 -08:00
Awni Hannun
cb6156d35d
Fix eval in trace bugs (#612)
* Fix eval in trace bugs

* comment nit
2024-02-02 09:57:12 -08:00
Piotr Rybiec
506d43035c
typo fix (#607) 2024-02-01 17:39:55 -08:00
Awni Hannun
e88e474fd1
Reduce vmap + some fixes (#601) 2024-02-01 11:30:28 -08:00
Vijay Krish
fcc5ac1c64
Add GPU support for uint64/int64 reductions (#569) 2024-01-31 11:18:04 -08:00
Angelos Katharopoulos
199aebcf77
Change the variance computation (#319) 2024-01-30 19:28:56 -08:00
Angelos Katharopoulos
0de5988f92
Custom VJP and checkpointing (#541)
* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
2024-01-30 16:04:45 -08:00
Jagrit Digani
375446453e
Update Compute Pipeline Creation API (#581)
* Add option to specialize metal functions on function constants
* Update Compute Pipeline Creation API
* Add options to make libraries from source and stitching
* Update function specialization name options
2024-01-30 15:42:36 -08:00
Angelos Katharopoulos
1895d34c20
Fix log1p with inf inputs (#592) 2024-01-30 14:02:50 -08:00
Jacket
3f7aba8498
Implement diagonal operator (#562)
* Implement diagonal operator

This implements mx.diagonal in operator level, inspired by
@ManishAradwad.

* added `mx.diag` with tests

* corrected few things

* nits in bindings

* updates to diag

---------

Co-authored-by: ManishAradwad <manisharadwad@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-30 09:45:48 -08:00
Angelos Katharopoulos
65d0b8df9f
Fix binary op dispatch (#584) 2024-01-29 19:36:17 -08:00
Awni Hannun
3c2f192345
Propagate nans in binary ops (#579)
* propagate nans in binary ops

* handle empty matmul

* cpu minimum/maximum propagate nan

* benchmark maximum

* add min as well

* throw on negative indices with full

* verbose on linux

* fix matmul for zero K
2024-01-29 11:19:38 -08:00
Awni Hannun
8993382aaa
Buffer Donation (#519)
* buffer donation

* fix to move shared pointer

* format

* gpu in place for copy and binary

* revert ops test

* cpu in place

* a little cleanup

* remove useless bench
2024-01-26 16:30:33 -08:00
Awni Hannun
07f35c9d8a
Fix a few issues: docs for flatten, erf, dequantize validation (#560)
* doc flatten

* erf doc

* check values for dequantize

* format
2024-01-26 15:16:46 -08:00
Jagrit Digani
bf17ab5002
Add more checks and clearer error messages to conv operations (#563)
* Add more checks and clearer error messages to conv operations
2024-01-26 15:13:26 -08:00
Awni Hannun
8fa6b322b9
Compile front-end (#476)
* fix tests for linux

* make a move on compile

* basic compile scaffold works

* compile binding

* clean

* fix

* fix grad, more tests

* basic python tests

* fix segfault on python exit

* compile works with python closures

* fix test

* fix python globals bug, and erase

* simplify

* more cpp tests

* bug fix with move function and compile at exit

* simplify inputs also

* enable and disable compiler

* remove simplify

* simplify tests use compile now

* fix multi-output with compile

* clear output tree from cache when function goes out of scope

* ../python/src/transforms.cpp

* remove closure capture

* comments
2024-01-26 13:45:30 -08:00
taher
077c1ee64a
QR factorization (#310)
* add qr factorization

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-26 09:27:31 -08:00
Rifur13
2463496471
[Fix] mx.allclose bug with infinite values (#539)
* Added isclose op and fixed comparison with inf values

* Added 'equal_nan' to match numpy

* format

* Add test

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Addressed CR comments

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* nits

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-25 20:47:06 -08:00
Awni Hannun
f27ec5e097
More helpful error message in vjp transform + concate bug (#543)
* more helpful message in vjp transform

* fix concatenate on mismatch dims

* typo

* typo
2024-01-24 09:58:33 -08:00
Awni Hannun
f30e63353a
Minor updates to address a few issues (#537)
* docs on arg indices return type

* arange with nan

* undo isort
2024-01-23 22:24:41 -08:00
Juarez Bochi
4fe2fa2a64
GGUF: Avoid dequantization when format is compatible (#426)
* GGUF: Don't dequantize q4_1

* Fix weight order. First in low bits

* Add unpacking for q4_0

* Don't dequantize q8_0

* rebase quants and split file

* don't quantize every weight

* reapply patch

* error handling

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 15:43:57 -08:00
Jagrit Digani
6d3bee3364
Fix oob reads in gemv kernel (#523) 2024-01-22 12:06:04 -08:00
Awni Hannun
7a34e46677
Quantize with groups of 32 (#511)
* allow quantize with group sizes of 32

* missing cpu dispatch

* remove print

* Fix qvm for group_size 32

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-01-21 06:19:05 -08:00
Awni Hannun
b207c2c86b
Power VJP fix for 0 (#505) 2024-01-20 01:17:40 -08:00
Juarez Bochi
ddf50113c5
GGUF: Load and save metadata (#446)
* gguf metadata
---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-19 14:06:05 -08:00
Awni Hannun
c4ec836523
fix isinf for integer types (#494) 2024-01-19 05:31:10 -08:00