Commit Graph

323 Commits

Author SHA1 Message Date
Awni Hannun
a2ffea683a
Fix eye for larger matrices (#463)
* fix eye
* fix scatter for <32bit (non native atomic) types
* fix int overflow
2024-01-16 00:51:24 -08:00
Angelos Katharopoulos
c15fe3e61b
Allow arbitrary first dimension in quantization kernels. (#458)
* Allow arbitrary first dim on qmm_t and qmv
* Allow arbitrary first dim on qmm and qvm
* Specialized aligned vs unaligned case
* Add more checks for valid quantizations
2024-01-16 00:46:21 -08:00
Tristan Bilot
f44c132f4a
Add scatter_min VJP (#462) 2024-01-16 00:37:40 -08:00
Matthew Ernst
92a2fdd577
Adds isinf (#445)
* adds isinf

Signed-off-by: matthewfernst <matthew.f.ernst@gmail.com>

* use stream + nits

* typo

---------

Signed-off-by: matthewfernst <matthew.f.ernst@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-15 19:50:44 -08:00
Tristan Bilot
6022d4129e
scatter_max vjp + bindings + tests (#431)
Co-authored-by: DjamelMesbah <djamel.mesbah@adservio.fr>
2024-01-14 14:12:15 -08:00
Awni Hannun
4bc446be08
Use a dummy primitive to only sync with one output (#453)
* Use a dummy primitive to only sync with one output
* Fix test and choose stream with slight care
2024-01-14 14:09:40 -08:00
Awni Hannun
41cc7bdfdb
Fix stub generation, change graph exporting for arrows to go to outputs (#455) 2024-01-14 14:06:16 -08:00
Awni Hannun
6e81c3e164
Sync only with outputs we need to sync with (#447) 2024-01-13 01:47:25 -08:00
Diogo
2e29d0815b
Add tile op (#438) 2024-01-12 23:03:16 -08:00
Ayush Shridhar
1416e7b664
Add isnan (#423) 2024-01-12 11:16:48 -08:00
Angelos Katharopoulos
006d01ba42
Fix packaging of gguflib (#435) 2024-01-11 13:56:03 -08:00
Awni Hannun
c9934fe8a4
Metal validation (#432)
* tests clear metal validation

* add cpp test with metal validation to circleci

* nit
2024-01-11 11:57:24 -08:00
Awni Hannun
3b4f066dac
Correct types for vjp + tests (#418)
* correct types for vjp + tests

* fix build + comment
2024-01-10 13:32:37 -08:00
Juarez Bochi
b7f905787e
GGUF support (#350)
* Initial GGUF support for tensor fields.

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-10 13:22:48 -08:00
Angelos Katharopoulos
961435a243
Scatter vjp (#394)
* Add a first scatter vjp
* Implement the scatter_add vjp
* Add array.at to implement user friendly scatters
2024-01-09 13:36:51 -08:00
Awni Hannun
f099ebe535
Multi output primitives (#330)
* Multi-output primitives

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-01-08 16:39:08 -08:00
Nripesh Niketan
73321b8097
feat: add logicalAnd and logicalOR (#386)
* feat: add logicalAnd and logicalOR

* run pre-commit

* Refactor logical_and and logical_or functions

* Add acknowledgement

* Add logical AND and logical OR operators

* Refactor logical_and and logical_or functions

* Add support for logical operators on bool arrays

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Add logical AND and OR operators for arrays and scalars

* Refactor vjp and jvp methods in primitives.cpp

* Add overloaded operators for logical AND and OR

* format

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-08 07:00:05 -08:00
Angelos Katharopoulos
a611b0bc82
Removes the retain_graph flag (#385)
* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
2024-01-07 15:16:51 -08:00
Diogo
449b43762e
Add inner / outer op (#348)
* inner / outer impl

* python tests

* ops list and ack

* updated descriptions

* use test helper

* removed dtype check and flatten outer to 1-D

* updated docs

* just use the reshape to flatten
2024-01-07 09:01:09 -08:00
Awni Hannun
c6d2878c1a
safely divide for 0 size inputs (#388) 2024-01-07 00:19:54 -08:00
Awni Hannun
b34bf5d52b
fix saving for non-contiguous arrays (#389) 2024-01-06 12:44:02 -08:00
Angelos Katharopoulos
608bd43604
Move the matmul type check in the op (#384) 2024-01-05 19:10:13 -08:00
mutexuan
d8f41a5c0f
support python mlx.array creation from list of mlx.array's (#325)
* support python mlx.array creation from list of mlx.array's

* include bfloat16 in UT

* refactor so that sub array made of all python primitive types gets initialized by fill_vector

* address PR comment: arr.shape().size() -> arr.ndim()

* address PR comment: get back Dtype constness and let stack to handle type promotions automatically
2024-01-04 18:53:33 -08:00
Awni Hannun
b9e415d19c
bump pre commit and fix format (#373) 2024-01-04 16:28:52 -08:00
davidkoski
c82a8cc526
move all ObjC (via metal-cpp) interaction until post static initializers (#370)
* move all ObjC (via metal-cpp) interaction until post static initializers

- metal-cpp relies on static initializers to cache class and selector pointers
- code in mlx was using metal-cpp to set up NSAutoreleasePools during its own static init time
- but this code was silently failing as the class and selector pointers from metal-cpp were still nil

- defer the creation of NSAutoreleasePools until after static init time
- ensure that we have coverage where autorelease pools are needed

* Update device.cpp

remove commented code

* Update device.cpp

remove commented out code

* Update scheduler.h

update comment

* per discussion use the pool inside the task() -- this will be metal only, not needed for cpu

* Update allocator.cpp

move pool to release/alloc area
2024-01-04 16:12:00 -08:00
Angelos Katharopoulos
e7f5059fe4
Support for quantized matmul with w and w^T (#349)
* Add the metal qvm implementation
* Add qmm_n
* Add gradient wrt to input for quantized_matmul
2024-01-03 14:22:36 -08:00
Diogo
0782a4573a
Add Tensordot op (#344) 2024-01-02 17:15:00 -08:00
Awni Hannun
99c80a2c8b
Memory allocation (#292)
* try alternative gc

* try no cache

* add forced swap

* remove cache for now

* add cache back

* change fit crtieria

* remove unused function

* nit in comment

* tune / fix allocation

* increase block limit to original
2024-01-02 11:59:19 -08:00
Josh Soref
44c1ce5e6a
Spelling (#342)
* spelling: accumulates

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: across

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: additional

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: against

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: among

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: array

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: at least

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: available

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: axes

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: basically

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: bfloat

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: bounds

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: broadcast

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: buffer

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: class

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: coefficients

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: collision

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: combinations

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: committing

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: computation

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: consider

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: constructing

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: conversions

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: correctly

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: corresponding

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: declaration

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: default

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: dependency

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: destination

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: destructor

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: dimensions

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: divided

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: element-wise

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: elements

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: endianness

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: equivalent

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: explicitly

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: github

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: indices

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: irregularly

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: memory

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: metallib

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: negative

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: notable

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: optional

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: otherwise

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: overridden

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: partially

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: partition

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: perform

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: perturbations

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: positively

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: primitive

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: repeat

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: repeats

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: respect

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: respectively

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: result

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: rounding

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: separate

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: skipping

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: structure

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: the

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: transpose

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: unnecessary

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: unneeded

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: unsupported

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

---------

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
2024-01-01 21:08:17 -08:00
Angelos Katharopoulos
a020a2d49d
Improve repeat using broadcasting and reshape (#318) 2023-12-29 21:40:20 -08:00
Bahaa
ff2b58e299
Add support for repeat (#278)
* add repeat function

* fix styling

* optimizing repeat

* fixed minor issues

* not sure why that folder is there xD

* fixed now for sure

* test repeat not repeat test

* Fixed

---------

Co-authored-by: Bahaa Eddin tabbakha <bahaa@Bahaas-MacBook-Pro.local>
2023-12-27 13:11:38 -08:00
Diogo
1f6ab6a556
Safetensor support (#215)
Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-27 02:06:55 -08:00
Gabrijel Boduljak
6b0d30bb85
linalg.norm (#187)
* implemented vector_norm in cpp

added linalg to mlx

* implemented vector_norm python binding

* renamed vector_norm to norm, implemented norm without provided ord

* completed the implementation of the norm

* added tests

* removed unused import in linalg.cpp

* updated python bindings

* added some tests for python bindings

* handling inf, -inf as numpy does, more extensive tests of compatibility with numpy

* added better docs and examples

* refactored mlx.linalg.norm bindings

* reused existing util for implementation of linalg.norm

* more tests

* fixed a bug with no ord and axis provided

* removed unused imports

* some style and API consistency updates to linalg norm

* remove unused includes

* fix python tests

* fixed a bug with frobenius norm of a complex-valued matrix

* complex for vector too

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-26 19:42:04 -08:00
Angelos Katharopoulos
9e6b8c9f48
Refactor the reduction kernels (#277) 2023-12-24 14:47:57 -08:00
Daniel Strobusch
7365d142a3
random.uniform must respect dtype, even if lower precision than "low" (#280)
Fix an edge case where random uniform returns a float32 array, even if a lower precision dtype is wanted due to adding the float32 "low" array.
2023-12-24 07:04:43 -08:00
Awni Hannun
8b227fa9af
fix no metal build (#276) 2023-12-23 19:18:10 -08:00
Ronan Collobert
cd3616a463
Revisit autorelease memory pools (#260)
* make general autorelease pool part of metal device

* make things simpler

* no metal backend support

* new_memory_pool -> new_scoped_memory_pool
2023-12-22 11:01:26 -08:00
Awni Hannun
2118c3dbfa
fix (#255) 2023-12-21 18:18:41 -08:00
Awni Hannun
a002797d52
A temporary fix (#254) 2023-12-21 17:59:15 -08:00
Daniel Strobusch
794feb83df
support arange for bfloat16 (#245) 2023-12-21 14:33:43 -08:00
Angelos Katharopoulos
b3916cbf2b
Improve names of quantization arguments (#235)
* Change the default quantization group_size to 64
* Rename groups to group_size and width to bits
2023-12-20 16:53:53 -08:00
Angelos Katharopoulos
57fe918cf8
Adds C++ and nn quantization utilities (#230)
* Add C++ de-/quantize ops
* Add quantize functions to the docs and tests
* Add a QuantizedLinear module
2023-12-20 14:17:38 -08:00
Angelos Katharopoulos
2807c6aff0
Implements divide for integer types and adds floor_divide op (#228)
* Add floor_divide
* Add floor_divide to the tests
* Add floor_divide to the docs
2023-12-19 20:12:19 -08:00
davidkoski
de892cb66c
fix for non-macos build issue on cblas.h (#227) 2023-12-19 17:01:59 -08:00
davidkoski
37024d899c
fixes for building with swiftpm (#225)
- clbas is part of veclib (compile failure)
- add SWIFTPM_BUNDLE #define to allow loading the metallib from a swiftpm resource bundle
2023-12-19 16:22:10 -08:00
Angelos Katharopoulos
dfa9f4bc58
An initial quantized matmul implementation (#205)
* Add quantized matvec
* Add quantized matrix matrix with 2nd matrix transposed
* Add quantized matmul tests
* Add a slow cpu quantized matmul
* Add a slightly faster vectorized cpu version
2023-12-18 23:18:57 -08:00
Abe Leininger
e6872a4149
Added linspace (#181)
* linspace ops support

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-18 19:57:55 -08:00
Angelos Katharopoulos
4d4af12c6f
Adds round op and primitive (#203) 2023-12-18 11:32:48 -08:00
Awni Hannun
0e5807bbcb
include optional (#202) 2023-12-17 22:01:35 -08:00
Cyril Zakka, MD
8eb56beb3a
Added clip function (#159)
* Added clip

* Added Python bindings

* Formatting

* Added cpp tests

* Added Python tests

* python bindings work

* rebase

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-17 20:00:29 -08:00
Awni Hannun
90d04072b7
fix build w/ flatten (#195) 2023-12-17 11:58:45 -08:00
__mo_san__
52e1589a52
implemented Flatten Module (#149)
* implemented flatten op

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-16 21:54:37 -08:00
Diogo
dc2edc762c
added tri / tril / triu (#170)
* added tri / tril / triu

* fixed tests

* ctest tests

* tri overload and simplified tests

* changes from comment

* more tests for m

* ensure assert if not 2-D

* remove broadcast_to

* minor tweaks

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-15 17:30:34 -08:00
Awni Hannun
2e02acdc83
add base kwarg to rope (#186) 2023-12-15 16:47:59 -08:00
Ronan Collobert
83f266c44c
Lazy metal_device_ initialization (#185)
This ensures it is defined when the Scheduler needs it.
2023-12-15 16:06:46 -08:00
Víctor Aguilar
f24200db2c
accross -> across (#183) 2023-12-15 13:46:50 -08:00
Jason
e28b57e371
Added mx.stack c++ frontend impl (#123)
* stack C++ operation + python bindings
2023-12-14 13:21:19 -08:00
Awni Hannun
e5851e52b1
Add move and swap axis, and vmap for slice, concat, and gather (#158)
* add move and swap axis, and vmap for slice, concat, and gather
2023-12-14 12:59:12 -08:00
Luca Arnaboldi
b93c4cf378
Floor and Ceil (#150)
* Implements Floor and Ceil Ops
2023-12-14 10:00:23 -08:00
Ikko Eltociear Ashimine
c3272d4917
Update conv.cpp (#145)
Peform -> Perform
2023-12-12 11:27:49 -08:00
Cyril Zakka, MD
e080290ba4
Added eye/identity ops (#119)
`eye` and `identity` C++ and Python ops
2023-12-11 12:38:17 -08:00
Awni Hannun
71d1fff90a
Bug fix in metal binary kernel dispatch for large arrays (#125)
* bug fix

* format
2023-12-10 16:12:31 -08:00
Angelos Katharopoulos
600db7d754
Fix build on Xcode 14 (#116)
* Fix build on Xcode 14

* Style fixes
2023-12-10 06:58:52 -08:00
Angelos Katharopoulos
2b714714e1
Add the remainder op (#85)
* Add remainder in the C++ backend
* Add the python binding and test
2023-12-08 15:08:52 -08:00
Angelos Katharopoulos
209404239b
Fix the accelerate dispatch for the power op (#70)
- The exponent and base were swapped because accelerate is using
  exponent-base instead of base-exponent
- Fix also the test for binary ops as it was testing op(x, x) which
  couldn't catch ordering errors like that
2023-12-08 10:58:03 -08:00
Awni Hannun
4e3bdb560c
random generation fix (#80)
Random generation fix
2023-12-08 10:40:57 -08:00
Jagrit Digani
d518b3b6a5
Fix gemv broadcasting bug (#6)
* Fix broadcasting bug in gemv
* Add relevant tests in test_blas.py
2023-12-05 14:15:43 -08:00
Awni Hannun
db487e6b1a format 2023-11-30 11:50:50 -08:00
Awni Hannun
46a39e5b1f copyright + ack 2023-11-30 11:12:53 -08:00
Awni Hannun
c1b6bf3f33 missing file 2023-11-29 12:38:32 -08:00
Jagrit Digani
e6306cfee9 jagrit's commit files 2023-11-29 10:52:08 -08:00
Angelos Katharopoulos
d1f86272a2 angelos's commit files 2023-11-29 10:42:59 -08:00
Awni Hannun
8ca7f9e8e9 awni's commit files 2023-11-29 10:30:41 -08:00