Commit Graph

308 Commits

Author SHA1 Message Date
Awni Hannun
3ae6aabe9f
throw for certain cases of non captured inputs in compile (#1401) 2024-09-09 14:54:31 -07:00
Max-Heinrich Laves
efeb9c0f02
Transposed Convolution (#1245)
* initial implementation for conv_transpose

ran pre-commit

implemented conv_transpose

updated conv_general docstring

updated conv_general docstring

updated code comments

removed commented run_conv_checks

updated acknowledgments

added missing entry to ops.rst

added op to nn.layers

resolved merge conflicts

* removed ConvolutionTranspose primitive as suggested by reviewer

removed ConvolutionTranspose primitive as suggested by reviewer

* remove transpose flag, add another test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-06 19:52:38 -07:00
Awni Hannun
ba3e913c7a
Simplifications for MLX C (#1396)
* simplifications for MLX C

* use vectors instead of map

* update examples
2024-09-06 19:16:50 -07:00
Awni Hannun
9592766939
add std as method (#1387)
* add std as method

* add std as method
2024-09-01 19:49:16 -07:00
Awni Hannun
dba2bd1105
Even Even Faster IO (#1374)
* even more faster io

* make reader pool static

* make python reader thread safe

* one more optimization
2024-08-29 16:05:40 -07:00
Awni Hannun
fcb65a3897
Even Faster I/O (#1369)
* try multithreading for faster IO

* smaller batch size

* Account for pread returning less than size

* nit

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-08-28 11:49:07 -07:00
Awni Hannun
291cf40aca
Some fixes to typing (#1371)
* some fixes to typing

* fix module reference

* comment
2024-08-28 11:16:19 -07:00
Aditya Dhulipala
e6b223df5f
Pinv (#875) 2024-08-27 23:06:12 -07:00
Angelos Katharopoulos
cdb59faea6
Adds send/recv ops in distributed (#1366) 2024-08-26 23:01:37 -07:00
Alex Barron
1d94ac3f90
Add optional headers to `mx.fast.metal_kernel` (#1358) 2024-08-26 21:45:45 -07:00
Awni Hannun
5f7d19d1f5
MPI ops in GPU stream for faster comms (#1356) 2024-08-26 15:12:50 -07:00
Alex Barron
d1183821a7
int() and float() for mx.array (#1360) 2024-08-25 20:41:44 -07:00
Alex Barron
b96e105244
Add grid_sample example to metal_kernel docs (#1352)
* Add `zero_outputs` and `atomic_outputs` options to `metal_kernel`

* add grid sample to docs

* zero_outputs -> init_value

* add missing header for linux
2024-08-23 18:24:16 -07:00
Awni Hannun
3b4d5484c7
Bump extension MLX version (#1350)
* Bump extension MLX version

* fix some docs nits
2024-08-23 12:38:34 -07:00
Alex Barron
0fd2a1f4b0
Custom Metal Kernels from Python (#1325)
* start

* simple kernels working

* restructure

* inverse example working

* docs + fixes

* missing file

* fix imports

* address comments

* add docs + fix test

* Review comments + refactor to a single function

* update docs

* remove hashing

* fix contig bug in test

* back to a class

* trailing whitespace

* fix tests

* match c++ and python apis

* add link + make args kw_only
2024-08-22 13:46:29 -07:00
Awni Hannun
bb1b76d9dc
RoPE with frequencies as optional input (#1337)
* start rope with freq input

* rope with frequencies

* nits

* fix bug

* fix bug + test

* cleanup

* optional base
2024-08-19 18:30:50 -07:00
Awni Hannun
eaaea02010
Add isfinite (#1318)
* isfinite

* remove reduce test since fix is not complete
2024-08-13 14:49:28 -07:00
Brian Keene
19fb69e2ed
Add memory_efficient_threshold kwarg to sdpa kernel (#1319)
Allows opt-in to memory efficient GPU shader at proscribed sequence
length.  Otherwise, utilizes aggregate MLX primitives for best latency.
2024-08-12 12:57:09 -07:00
Awni Hannun
9231617eb3
Move to nanobind v2 (#1316) 2024-08-08 17:17:46 -07:00
Alex Barron
32668a7317
CPU mx.linalg.cholesky_inverse and mx.linalg.tri_inv (#1307)
* add cholesky inv + tri inv

* always run tri_inv on cpu

* consistent naming
2024-08-08 15:18:02 -07:00
Alex Barron
635ccd9e25
Add "edge" mode to mx.pad (#1309)
* Add edge padding mode

* fix pad in pooling

* string arg instead of enum
2024-08-06 11:23:10 -07:00
Awni Hannun
10b5835501
fix creating array from bf16 tensors in jax / torch (#1305) 2024-08-01 16:20:51 -07:00
Awni Hannun
40b6d67333
Fixes for large arrays with a few ops (#1299)
* fixes for large arrays with a few ops

* fix bug

* fix all of copy
2024-07-30 17:18:39 -07:00
Alex Barron
c52d1600f0
Fused Affine Quantize/Dequantize ops (#1282)
* Add fast affine dequantize

* add full quantize kernel

* fused kernel with scale/bias computation

* fix docstring

* fix no jit error

* fix test

* test fix

* reduce fast api to only affine_quantize
2024-07-29 15:11:38 -07:00
Awni Hannun
aa1d6cadad
Fix docs latex build and nits (#1297)
* fix docs latex build and nits

* fix stub gen and try to clean up building
2024-07-29 11:44:06 -07:00
Awni Hannun
7b456fd2c0
Array api (#1289)
* some updates for numpy 2.0 and array api

* some updates for numpy 2.0 and array api

* fix array api doc
2024-07-26 10:40:49 -07:00
Anton Belov
5029894662
[Issue #1187] Add nan_to_num function initial attempt (#1247)
* initial attempt, working with wrong types

* not compiling; mx.float16 and mx.bfloat16 tests added

* fix nan to num

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-07-25 09:57:37 -07:00
Awni Hannun
baf9fa5f42
Einsum (#1269)
* einsum initial

* fix comma break

* sum axis was wrong

* small cleanups

* python binding

* changed bindings to resemble numpy

* remove todo comment

* comment changes

* add count of operands/inputs

* fail fast if operands list is empty

* ignore comma if no output

* einsum path matching numpy

* getting somewhere with path

* remove print

* it passes the first test

* moved einsum tests to seperate file

* seperated einsum path

* moved einsum naive

* remove space from equation

* fast fail if no operands passed

* update tests and remove printf

* small cleanup

* some more cleanups

* removed python helper file

* ack

* utilize std for finding min in vector

* duplicate def

* remove the tuple as it was unreadable

* moved einsum_naive back to ops

* remaining isn't needed

* avoid creating another set

* cleanup

* greedy path, start of naive einsum

* more einsum

* fix some bugs

* some more fixes, tests pass

* benchmark

* some simplify

* fix einsum and test

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>

* add a bunch more tests and fix a bunch more bugs

* some docs nits

---------

Co-authored-by: dc-dc-dc <dgcruz983@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-07-25 09:36:44 -07:00
fgranqvist
50eff6a10a
Implement sampling from laplace distribution. (#1279) 2024-07-24 15:15:37 +02:00
Alex Barron
c34a5ae7f7
Fix bfloat16 Hadamard (#1283)
* fix bfloat16 hadamard

* add scale

* review comments

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-07-23 14:54:43 -07:00
Awni Hannun
e2aa6ec8ae
some fixes (#1281) 2024-07-23 11:49:05 -07:00
Awni Hannun
218047c75a
docs fixes (#1263) 2024-07-11 15:59:07 -07:00
Angelos Katharopoulos
5c1fa64fb0
Custom transforms (#1246) 2024-07-10 18:00:01 -07:00
Alex Barron
a3c287354f
Fast Hadamard Transform (#1249)
* Working hadamard for powers of 2

* working for m*2^k

* add scale and check contiguity

* add size check

* clean up

* fix test

* add grads + vmap

* gpu only

* skip on linux

* test typo

* add cpu impl

* remove gpu only tests

* fix linux build + add is_equivalent
2024-07-09 20:39:01 -07:00
Awni Hannun
20bb301195
CPU binary reduction + Nits (#1242)
* very minor nits

* reduce binary

* fix test
2024-06-28 13:50:42 -07:00
Angelos Katharopoulos
b05bcfd27f
Fixes segfault when compiling checkpointed functions (#1235) 2024-06-26 16:14:45 -07:00
David Koski
4eef1e8a3e
fix typo (#1215) 2024-06-24 13:36:35 -07:00
Alex Barron
95d11bda06
Fix NumPy 2.0 pickle test (#1221)
* fix numpy version <2 temporarily

* typo

* better fix

* Fix just for bfloat16

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-23 05:47:22 -07:00
Angelos Katharopoulos
0163a8e57a
Add docs for the distributed namespace (#1184) 2024-06-06 11:37:00 -07:00
Awni Hannun
ea9090bbc4
Add view op (#1179)
* add view primitive

* nit

* fix view
2024-06-04 08:05:27 -07:00
Angelos Katharopoulos
3de8ce3f3c
In place all-reduce and forgiving init (#1178) 2024-06-03 16:47:47 -07:00
Brian Keene
1865299a30
Metal shaders for memory efficient self attention on large sequences (#964)
* Metal shaders for efficient self attention on large sequences

Updated fast attention: GEMM-ified with Steel primitives
Uses flash attention 1 for scale correction

* more compiler silencing

* Address rebase issues

* Templatize kernel instantiation, revise cpu bindings

* Safer writes to output

* Permit batch size > 1

* Numerical fixes for sdpa self attention

* Re-enable test, remove unused variable

* add benchmarking script

* Disable sdpa prior to perf tuning, and simplify tests for per-patch CI
2024-06-03 09:16:19 -07:00
K Venkat Ramnan
ab977109db
feat: Added dlpack device (#1165)
* feat: Added dlpack device

* feat: Added device_id to dlpack device

* feat: Added device_id to dlpack device

* doc: updated conversion docs

* doc: updated numpy.rst dlpack information

* doc: updated numpy.rst dlpack information

* Update docs/src/usage/numpy.rst

* Update docs/src/usage/numpy.rst

---------

Co-authored-by: Venkat Ramnan Kalyanakumar <venkatramnankalyanakumar@Venkats-MacBook-Air.local>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-05-31 12:29:01 -07:00
Awni Hannun
7e26fd8032
Option to JIT steel gemm / conv (#1139) 2024-05-23 18:07:34 -07:00
Angelos Katharopoulos
50dfb664db
Comms (#1097)
* Start the communications branch using MPI
* Add ops and primitives
* Add python bindings for distributed
2024-05-23 17:04:02 -07:00
Awni Hannun
eb8321d863
list based indexing (#1150) 2024-05-22 15:52:05 -07:00
Abe Leininger
79ef49b2c2
add mx.trace (#1143) (#1147)
* working c++ trace implementation

* updated throw + added overloads

* added python binding for trace function

* pre-commit reformatting

* add trace to docs

* resolve comments

* remove to_stream call
2024-05-22 15:50:27 -07:00
Awni Hannun
d568c7ee36
Rename block sparse (#1149)
* block_sparse_mm to gather_mm

* rename

* nit

* nit
2024-05-22 07:48:34 -07:00
Awni Hannun
e6fecbb3e1
Some fixes in docs (#1141)
* fixes in docs

* nit
2024-05-20 11:51:47 -07:00
Luca Arnaboldi
b3ec792380
Implemented Cholesky on CPU (#1119) 2024-05-17 12:31:59 -07:00
Awni Hannun
81dd33af66
allow conversion to dlpack (#1120) 2024-05-16 16:11:37 -07:00
Angelos Katharopoulos
e78a6518fa
Block sparse qmm (#1124) 2024-05-16 15:24:14 -07:00
Jacket
c417e42116
[Fix] minor typo in default argument for argpartition's "axis" parameter (#1125)
According to the document, argpartition's axis parameter can be None, but due to a previous typo it can't really accepts a None value.
2024-05-15 15:25:25 -07:00
Awni Hannun
631dfbe673
fix scatter index bug (#1122) 2024-05-14 15:04:58 -07:00
Cheng
56a4eaed72
Pass missing stream arg in array.flatten (#1111) 2024-05-14 06:50:16 -07:00
Cheng
bf925d9dc7
Move args in conv_general (#1118)
Also fix a typo that padding_lo is passed as padding_hi.
2024-05-14 06:50:09 -07:00
Cheng
1a7ed5dcb6
Fill vector with constructor instead of fill_n (#1113) 2024-05-14 06:28:55 -07:00
Cheng
cbd5445ea7
The tile op does not accept None as reps (#1117) 2024-05-14 06:25:25 -07:00
Max-Heinrich Laves
ff4223904d
Conv3d (#993)
* added conv3d

added conv3d

implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D

* incorporated reviewer comments

* fixed test

* reduced tensor shapes in test for conv3d

* Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion
2024-05-11 06:15:02 -07:00
Alex Barron
2e158cf6d0
Add conjugate operator (#1100)
* cpu and gpu impl

* add mx.conj and array.conj()

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-05-10 07:22:20 -07:00
Awni Hannun
b21242faf1
Allow unary ops to accept array like (#1093) 2024-05-09 09:36:02 -07:00
Rahul Yedida
cc05a281c4
Added ArcTan2 operation (#1079)
* Added ArcTan2 operation

* Cleanup, bug fixes from code review

* Minor cleanup, fixed Linux tests
2024-05-08 08:35:15 -07:00
Awni Hannun
9814a2ae12
fix conversion to array (#1070) 2024-05-06 16:02:49 -07:00
Shubham
6992498e7a
add keyword positonal (#1081) 2024-05-06 07:18:49 -07:00
Awni Hannun
21623156a3
Reset peak memory (#1074)
* reset peak memory

* fix linux

* nits in docs
2024-05-03 17:12:51 -07:00
Awni Hannun
b00ac960b4
change initial memory limits and add memory size to device info (#1064) 2024-05-03 06:50:15 -07:00
Jagrit Digani
f390957685
Block sparse mm (#1058) 2024-05-02 14:03:58 -07:00
Awni Hannun
19bef39f5c
Add a mx.metal.device_info (#1060)
* device inof

* add variant

* fix linux

* fix doc
2024-04-30 15:47:27 -07:00
Awni Hannun
09f1777896
fix slice update indexing (#1053) 2024-04-29 12:17:40 -07:00
Jacket
490c0c4fdc
[Fix] expand axes for dimension with integer indices in mlx_slice_update (#1035)
* Not sure if this is correct

* Format

* Edit tests

* Add negative test

* Format

* add one more test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-29 07:57:28 -07:00
Awni Hannun
86f495985b
Add bitwise ops (#1037)
* bitwise ops

* fix tests
2024-04-26 22:03:42 -07:00
Awni Hannun
5bfe89bdb1
Cpp docs (#1036)
* start of C++ docs

* fix stream doc

* only include ops for now
2024-04-26 12:56:05 -07:00
Awni Hannun
771575d27b
Expose function to clear memory cache (#1032)
* expose function to clear memory cache

* fix linux build

* fix metal tests
2024-04-24 16:48:51 -07:00
Aneesh Shetty
d0dbfe0b97
Adds radians and degrees (#1011) 2024-04-22 11:17:49 -07:00
Awni Hannun
3d405fb3b1
Add synchronize function (#1006)
* add synchronize function

* fix linux

* fix linux

* fix and fix docs

* fix test

* try synchronize in stream destroy

* synchronize works for both cpu and gpu
2024-04-22 08:25:46 -07:00
Angelos Katharopoulos
ef5f7d1aea
Fix buffer protocol buffer size designation (#1010) 2024-04-19 06:06:13 -07:00
Awni Hannun
8a0677d56d
Shared events for synchronization + async eval (#998)
* more async eval

* fix rebase

* try correct async eval

* fix async

* more tests for async eval

* use shared events for synchronization

* comment + cleanup

* with autorelease pool

* fix no metal build

* fix compile

* fix patch

* don't eval if asyn evale'd

* don't use is_evaled

* comments

* more multi stream tests

* try and cleanup use of is_evaled

* use a status flag
2024-04-17 06:16:02 -07:00
Jagrit Digani
b18468bf81
Masked mm (#978)
* Add block masked matmul op and primitive
2024-04-16 14:45:39 -07:00
Awni Hannun
12d4507ee3
Explicit barriers with concurrent dispatch (#977) 2024-04-10 21:45:31 -07:00
Awni Hannun
99abb9eff4
Async eval (#972) 2024-04-09 18:34:00 -07:00
Luca Arnaboldi
fffe072028
Implementation of mlx.random.multivariate_normal (#502) (#877)
* Implementation of mlx.random.multivariate_normal (#502)

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Updated typo in docstring

* Restricted multivariate_normal to  float32

* Generic mean and variance shapes

* Review edits

* Update mlx/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Test for ndim of mean and cov

* nits

* smaller size for test

* fix broadcasted sampling

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-09 13:50:12 -07:00
Abe Leininger
a1a31eed27
Add mx.meshgrid (#961) 2024-04-09 11:43:08 -07:00
Awni Hannun
42afe27e12
std and expm1 (#973)
* std and expm1

* actually add expm1

* fix linux

* fix vjp

* relax tol for linux test

* Add it to the compilable primitives

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-08 14:26:01 -07:00
Awni Hannun
aac2f9fb61
Improve profiling with gpu tracing (#969)
* improve profiling with gpu tracing

* fix for linux

* nit

* doc fix

* fix example
2024-04-07 21:47:43 -07:00
Awni Hannun
e142aaf8a1
Option for precise softmax (#953)
* precise softmax

* Add an equivalency check

* Make the threadgroup memory definition fixed

* precise cpu softmax

* precise option on cpu

* remove print

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-04 08:32:35 -07:00
AmirHossein_Razlighi
0caf35f4b8
Better exceptions in case of invalid operations on mlx.core.array (#910) (#926)
* Nicer exceptions for ops on non-arrays
2024-04-02 21:11:24 -07:00
Angelos Katharopoulos
3fc993f82d
Properly handle negative axes in python vmap (#944) 2024-04-02 18:07:23 -07:00
Jagrit Digani
639e06e1f3
Indexing bug fix (#947)
* Fix axes accounting

* Add tests
2024-04-01 12:18:50 -07:00
Angelos Katharopoulos
02fedbf1da
Fix array initialization from list (#942)
* Fix array initialization from list

* Change the error message in the test
2024-04-01 06:27:52 -07:00
AmirHossein_Razlighi
f48bc496c7
Comparing python objects (such as list/tuple) with mlx.core.array (#920)
* add implicit conversion of list to array for equality constraint

* add tests for array equality

* add test for tuple and array equality

* return False if __eq__ arg is list or tuple

* write tests for equality

* update the rule of comparison for __ge__/__gt__/__lt__/__le__

* add a helper function for detecting mlx.core.array

* return true in case fo inequality

* debug minor issue regarding detecting mlx array

* add tests for inequality comparisons

* add name for contribution

* reformat files using pre-commit

* update tests for float

* update tests for inequality

* raise exception in case of invalid comparisons

* use isinstance instead of string comparison

* replace "is_convirtable_to_array" with previous logic

* remove throwing exceptions for other operations

* just a comment

* minor changes for efficiency

* optimize a utils function

* change the function name

* Update ACKNOWLEDGMENTS.md

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-29 06:52:30 -07:00
Cheng
46caf0bef0
Remove unnecessary string copies (#891)
1. Use string_view instead of string when there is no need for copy.
2. Otherwise move string when possible.
2024-03-28 13:14:59 -07:00
Cheng
a7b404ff53
Use uintptr_t instead of size_t to store funtion id (#916)
Also does some small cleanup of the compile cache code.
2024-03-28 06:37:59 -07:00
Abdussamet Türker
5611e1a95e
Fix unsqueeze with None (#899)
* Fix unsqueeze with None

* Clean unnecessary files
2024-03-26 13:59:44 -07:00
Jack Mousseau
8e686764ac
Ensure shape dimensions are within supported integer range (#566) (#704)
* Ensure shape dimensions are within supported integer range (#566)

* fix build

* fix rebase bug

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-25 13:29:45 -07:00
Daniel Strobusch
479051ce1c
add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate

numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)).

Closes #285.

* nits in docs

* unify type category checking

* nits in docs

* nits in docs

* more docs nits

* fix callable type

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-25 12:32:59 -07:00
Awni Hannun
1e16331d9c
post nanobind docs fixes and some updates (#889)
* post nanobind docs fixes and some updates

* one more doc nit

* fix for stubs and latex
2024-03-24 15:03:27 -07:00
Awni Hannun
be98f4ab6b
Reduce a little overhead (#871)
* some small overhead improvements

* use result_type in rms_norm

* remove release force

* fix + use non-vector version

* revert compile change

* fix ops

* a little more overhead

* a little more cleanup and overhead
2024-03-22 17:29:36 -07:00
Jagrit Digani
8e5a5a1ccd
Set item bug fix (#879)
* set item shaping bug fix

* Add extra tests
2024-03-22 12:11:17 -07:00
Cheng
9663c22fe9
Do not store iostream in shared_ptr (#872)
There is no need to store iostream in shared_ptr, doing so adds the cost
of a heap allocation.
2024-03-22 06:54:45 -07:00
Awni Hannun
44390bd3d0
Bump (#869)
* bump

* fix none in a few ops
2024-03-21 13:56:56 -07:00
Angelos Katharopoulos
2225374060
Adds mx.fast.layer_norm (#870) 2024-03-21 13:55:51 -07:00
Awni Hannun
a54f06b16f
Fast RMS Norm (#862)
* fast rmsnorm

* no rms gpu

* kernel

* fix shared mem

* looped rms and donation in softmax

* Make the squaring in float32 to avoid underflow

* Fix the default StreamOrDevice for rope and rms_norm in fast

* nits

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-21 07:20:54 -07:00
Jagrit Digani
a5681ebc52
Update set item (#861)
* Update mlx_set_item to handle regular slices without expanding

* Refactor ellipsis handling

* Route mlx_set_item to slice_update where possible

* Update mlx_scatter_args_slice

* Don't route to gather if no array indices
2024-03-21 02:48:13 -07:00
Md. Rasel Mandol
db6796ac61
simple typo fille (#848) 2024-03-19 06:15:17 -07:00
Awni Hannun
9a8ee00246
Switch to nanobind (#839)
* mostly builds

* most tests pass

* fix circle build

* add back buffer protocol

* includes

* fix for py38

* limit to cpu device

* include

* fix stubs

* move signatures for docs

* stubgen + docs fix

* doc for compiled function, comments
2024-03-18 20:12:25 -07:00
nicolov
eaba55c9bf
Add matrix inversion primitive (#822) 2024-03-15 06:34:36 -07:00
nicolov
d0c544a868
Add SVD primitive (#809)
Add SVD op using Accelerate's LAPACK following
https://developer.apple.com/documentation/accelerate/
compressing_an_image_using_linear_algebra

Co-authored-by: Nicolo Valigi <nvaligi@apple.com>
2024-03-12 12:30:11 -07:00
Daniel Falbel
ffb19df3c0
Fix docstring for correctly rendering (#820) 2024-03-12 11:46:44 -07:00
Awni Hannun
28301807c2
Version bump and os error (#807) 2024-03-07 13:57:58 -08:00
Awni Hannun
b7588fd5d7
fix inplace to not make a shallow copy (#804) 2024-03-07 09:34:11 -08:00
Luca Arnaboldi
cbefd9129e
Implementation of pickle, copy and deepcopy for Python arrays (#300 & #367). (#713)
* Implemented pickling and copy for Python arrays(#300 & #367)

* Fixing typos

* Pickle with NumPy arrays

* Pickle: workaround for bfloat16

* Revert "Pickle: workaround for bfloat16"

This reverts commit 25afe6bc09.

* Added an error when pickling bfloat16

* Update python/tests/test_array.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/tests/test_array.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/array.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/array.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* clang-format applied

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-06 08:02:41 -08:00
Brian Keene
0787724c44
Fast Inference SDPA op (#735)
* Fast Inference SDPA op

Implements metal shaders for:

o = mx.fast_inference_sdpa(queries, keys, values, scale, mask)

Supports fp16, fp32 dtypes; assumes d_k = 128.

Generic op support / prompt encoding supported via mlx primitives.
Metal implementation is for the inference use case only.

Majority of performance benefits appears to results from GQA & reduced
bandwidth requirements; there is approximate performance parity for the
MHA use case (from some measurements on M3 Max).

* Flush shared memory to zero before unprotected reads for (scores @ values)

* Move to fast:: namespace, address reviewer comments

... also attempt to revert formatter auto-change for files not relevant
to this change

* Shared memory flush to top of kernel

* Resolve compiler warnings

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update docstring per PR feedback

* Softmax in higher precision, ...

* route to fallback for more use cases - batch size > 1, head_dim other
  than 128, etc.
* Address linux build failure
* Address other reviewer comments

* Remove extraneous eval_cpu function per review

---------

Co-authored-by: Atila Orhon <64497909+atiorh@users.noreply.github.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: atila <atiorh@icloud.com>
2024-03-04 21:06:11 -08:00
Awni Hannun
5121f028d9
nice tensordot for mlx c (#782) 2024-03-04 09:51:02 -08:00
Awni Hannun
bc06cb9ff6
Pickle + dtype fix for numpy conversion (#763)
* pickle + dtype fix for numpy conversion

* fix getattribute on Module base

* remove unused function

* fix tests

* add topk to ops

* fix doc
2024-03-02 06:09:29 -08:00
Awni Hannun
d5964a2710
bindings for memory info (#761)
* bindings for memory info

* update api

* keep cache low if requested

* fix default

* nit in ops error
2024-03-01 19:51:58 -08:00
Ikko Eltociear Ashimine
cf3eb87e52
Fix typo in transforms.cpp (#764)
occuring -> occurring
2024-02-29 22:23:46 -08:00
Jagrit Digani
776c3d226d
Convolution update (#651)
* Init steel conv and update Conv primitive

* Update slow CPU implementation to support flipping and input dilation winograd conv routing

Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-28 20:11:16 -08:00
Awni Hannun
420ff2f331
Add back compiled function signatures and docstrings (#749)
* try to add back compiled function signatures and docstrings

* add indentation to docstring
2024-02-27 13:18:59 -08:00
Awni Hannun
fe1dabf272
Fix compile with non standard types (#745)
* refactor tree utils

* fix compile + tree code refactor

* Add an extra test

* add a few missing activations to docs

* hash structure

* Encode the full argument structure

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-26 19:28:53 -08:00
Hinrik Snær Guðmundsson
08226ab491
added atleast *args input support (#710)
* added atleast list(array) input support

* function overloading implemented

* Refactoring

* fixed formatting

* removed pos_only
2024-02-26 11:17:59 -08:00
Awni Hannun
5798256fcf
Shapeless compilation for some graphs (#687)
* shapeless compilation for some graphs

* update compile benchmark

* default compile a few activations

* buffer donation

* bugfix

* shapeless fix

* update tests to work for cpu and gpu fusion

* test kwargs

* add kwargs to compile

* Recompile when python arguments change

* no compile for tanh

* some constant tests

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-19 21:43:54 -08:00
Awni Hannun
d0fda82595
fix tolist for half types (#702) 2024-02-19 09:44:27 -08:00
Hinrik Snær Guðmundsson
f883fcede0
Added support for atleast_1d, atleast_2d, atleast_3d (#694) 2024-02-19 09:40:52 -08:00
Srimukh Sripada
818cda16bc
Support LR schedulers (#334)
* Add a few LR schedulers

* Move parents's constructor call to the top

* Fix docstring

* refactor optimizers into two files

* add docs

* nit

* Fix Callable type annotation for python 3.8

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-15 11:26:20 -08:00
Diogo
35431a4ac8
Adds device context manager (#679) 2024-02-14 14:14:58 -08:00
Awni Hannun
ccf1645995
Custom primitive + RoPE fat op (#676)
* extensions start

* rope custom op

* fix build

* docs + rope benchmark

* fix test

* Add a Metal kernel for RoPE

* Fix position of traditional

* transform tests

* Move rope computation to float and fix tests

* Fix the test and a typo

* change to fast

* fix no metal build

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-14 14:04:25 -08:00
Diogo
b57bd0488d
Metadata support for safetensors (#639)
* metadata support for safetensors

* aliases making it alittle more readable

* addressing comments

* python binding tests
2024-02-08 19:33:15 -08:00
Awni Hannun
5c03efaf29
Compile docs (#653)
* compile docs

* docs nits + comments
2024-02-08 11:21:50 -08:00
Awni Hannun
1b97b2958b
Compile with capture (#629)
* Simple kernel generation

* Remove the generate kernel from graph_utils

* fix multi-output with compile

* fuse with stopgrad

* v1 input, output capture in compile

* cleanup tree update with visitor update

* nit

* remove todo

* state for model, optional explicit init and more pure optimizer steps

* move learning rate to state

* add lr to opt state, some fixes in capture

* fix optim

* update tuple of containers as well

* fix stream for compiled output

* rng state for compile

* nit

* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-07 17:29:22 -08:00
Noah Farr
5fd11c347d
Add loc and scale to random.normal (#638)
* Add loc and scale to random.normal

* Add tests for loc and scale for random.normal

* Run pre-commit hooks

* Fix code review
2024-02-07 11:49:59 -08:00
Awni Hannun
d75ae52ecd
Compile primitive (#571)
* Compiled primitive with basic binary, unary graph-level fusion
2024-02-05 06:51:22 -08:00
Awni Hannun
5c3ac52dd7
fix test (#627) 2024-02-04 16:18:03 -08:00
Daniel Strobusch
4fd2fb84a6
make python array SupportsAbs conform (like numpy) (#624) 2024-02-04 09:31:02 -08:00
Daniel Strobusch
9852af1a19
fix "shape" docstring. (#623) 2024-02-04 09:21:22 -08:00
Angelos Katharopoulos
0de5988f92
Custom VJP and checkpointing (#541)
* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
2024-01-30 16:04:45 -08:00
Awni Hannun
09b9275027
Make shape a tuple (#591)
* shape tuple

* also remove simplify from docs

* rebase
2024-01-30 13:11:01 -08:00
Jacket
3f7aba8498
Implement diagonal operator (#562)
* Implement diagonal operator

This implements mx.diagonal in operator level, inspired by
@ManishAradwad.

* added `mx.diag` with tests

* corrected few things

* nits in bindings

* updates to diag

---------

Co-authored-by: ManishAradwad <manisharadwad@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-30 09:45:48 -08:00
Angelos Katharopoulos
37d98ba6ff
No gil eval (#565) 2024-01-26 22:03:52 -08:00
Awni Hannun
07f35c9d8a
Fix a few issues: docs for flatten, erf, dequantize validation (#560)
* doc flatten

* erf doc

* check values for dequantize

* format
2024-01-26 15:16:46 -08:00
Jagrit Digani
bf17ab5002
Add more checks and clearer error messages to conv operations (#563)
* Add more checks and clearer error messages to conv operations
2024-01-26 15:13:26 -08:00
Awni Hannun
8fa6b322b9
Compile front-end (#476)
* fix tests for linux

* make a move on compile

* basic compile scaffold works

* compile binding

* clean

* fix

* fix grad, more tests

* basic python tests

* fix segfault on python exit

* compile works with python closures

* fix test

* fix python globals bug, and erase

* simplify

* more cpp tests

* bug fix with move function and compile at exit

* simplify inputs also

* enable and disable compiler

* remove simplify

* simplify tests use compile now

* fix multi-output with compile

* clear output tree from cache when function goes out of scope

* ../python/src/transforms.cpp

* remove closure capture

* comments
2024-01-26 13:45:30 -08:00
taher
077c1ee64a
QR factorization (#310)
* add qr factorization

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-26 09:27:31 -08:00
Rifur13
2463496471
[Fix] mx.allclose bug with infinite values (#539)
* Added isclose op and fixed comparison with inf values

* Added 'equal_nan' to match numpy

* format

* Add test

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Addressed CR comments

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* nits

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-25 20:47:06 -08:00
Awni Hannun
f30e63353a
Minor updates to address a few issues (#537)
* docs on arg indices return type

* arange with nan

* undo isort
2024-01-23 22:24:41 -08:00
Awni Hannun
98c37d3a22
use axes in tensordot (#525) 2024-01-22 21:17:00 -08:00
Awni Hannun
6bf779e72b
fix array from list for > 32 bit types (#501) 2024-01-19 15:49:25 -08:00
Juarez Bochi
ddf50113c5
GGUF: Load and save metadata (#446)
* gguf metadata
---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-19 14:06:05 -08:00
Ethan
a749a91c75
Support disable metal buffer cache to prevent performance degradation caused by large memory caching (#390)
* support disable metal buffer cache, due to large unused memory buffered when llm generated long context tokens

* Run format and add "cache_enabled" feature tests
2024-01-18 08:33:34 -08:00
toji
49a52610b7
Added formatter structure and a boolean value formatter (#354)
* added formatter structure and a boolean value formatter

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-18 07:49:41 -08:00
Angelos Katharopoulos
9c111f176d
Fix split optimization for array iterator (#484) 2024-01-18 05:50:25 -08:00