Commit Graph

143 Commits

Author SHA1 Message Date
AtomicVar
d1fef34138
Add Gaussian NLL loss function (#477)
* Add Gaussian NLL loss function

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-18 06:44:44 -08:00
Angelos Katharopoulos
9c111f176d
Fix split optimization for array iterator (#484) 2024-01-18 05:50:25 -08:00
Angelos Katharopoulos
90c234b7ac
Fix round to round half-cases to even (#482) 2024-01-17 15:27:23 -08:00
Jagrit Digani
78102a47ad
Update GEMM (#424)
* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/`
* Update gemm elements for better performance 
* Add split-K specialization for gemm
* Add `addmm` primitive, op and bindings for fused matmul and bias addition 
* Update tests and benchmarks as needed
2024-01-17 12:42:39 -08:00
Awni Hannun
a2bf7693dd
Primitive's VJP takes outputs as input (#475)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-01-16 19:03:53 -08:00
Angelos Katharopoulos
d8fabaa12b
Split multi output (#461)
* Multi-output split primitive
* Add the multi-output split to the ArrayIterator
* Add some grad tests for split
2024-01-16 13:33:55 -08:00
Yashraj Singh
e72458a3fa
implemented isposinf and isneginf in one PR (#470)
* ran precommit

* updated docs
2024-01-16 06:48:07 -08:00
Angelos Katharopoulos
c15fe3e61b
Allow arbitrary first dimension in quantization kernels. (#458)
* Allow arbitrary first dim on qmm_t and qmv
* Allow arbitrary first dim on qmm and qvm
* Specialized aligned vs unaligned case
* Add more checks for valid quantizations
2024-01-16 00:46:21 -08:00
Tristan Bilot
f44c132f4a
Add scatter_min VJP (#462) 2024-01-16 00:37:40 -08:00
Matthew Ernst
92a2fdd577
Adds isinf (#445)
* adds isinf

Signed-off-by: matthewfernst <matthew.f.ernst@gmail.com>

* use stream + nits

* typo

---------

Signed-off-by: matthewfernst <matthew.f.ernst@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-15 19:50:44 -08:00
Tristan Bilot
6022d4129e
scatter_max vjp + bindings + tests (#431)
Co-authored-by: DjamelMesbah <djamel.mesbah@adservio.fr>
2024-01-14 14:12:15 -08:00
Awni Hannun
41cc7bdfdb
Fix stub generation, change graph exporting for arrows to go to outputs (#455) 2024-01-14 14:06:16 -08:00
Diogo
2e29d0815b
Add tile op (#438) 2024-01-12 23:03:16 -08:00
Awni Hannun
1b71487e1f
docs (#444) 2024-01-12 13:34:16 -08:00
Ayush Shridhar
1416e7b664
Add isnan (#423) 2024-01-12 11:16:48 -08:00
davidkoski
29081204d1
array.swapaxes should point to swapaxes free function (#441) 2024-01-12 11:06:16 -08:00
Avikant Srivastava
975e265f74
feat: Add numpy constants (#428)
* add numpy constants

* feat: add unittests

* add newaxis

* add test for newaxis transformation

* refactor
2024-01-11 06:47:29 -08:00
Awni Hannun
3b4f066dac
Correct types for vjp + tests (#418)
* correct types for vjp + tests

* fix build + comment
2024-01-10 13:32:37 -08:00
Juarez Bochi
b7f905787e
GGUF support (#350)
* Initial GGUF support for tensor fields.

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-10 13:22:48 -08:00
Chunyang Wen
e3e933c6bc
Add type hint for Module (#412) 2024-01-10 11:23:42 -08:00
Awni Hannun
1d90a76d63
in place ops behave in place, fix some overloads (#411) 2024-01-09 16:05:38 -08:00
Angelos Katharopoulos
961435a243
Scatter vjp (#394)
* Add a first scatter vjp
* Implement the scatter_add vjp
* Add array.at to implement user friendly scatters
2024-01-09 13:36:51 -08:00
Awni Hannun
e9ca65c939
Fix BN stats to not expand shape (#409)
* fix BN stats to not expand shape

* nit
2024-01-09 11:54:51 -08:00
Awni Hannun
f099ebe535
Multi output primitives (#330)
* Multi-output primitives

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-01-08 16:39:08 -08:00
YUN, Junwoo
0b8aeddac6
Additoinal losses (#336)
* cosine similarity loss

---------

Co-authored-by: Awni Hannun <awni@apple.com>

* Docstring nits
2024-01-08 14:01:13 -08:00
Nripesh Niketan
73321b8097
feat: add logicalAnd and logicalOR (#386)
* feat: add logicalAnd and logicalOR

* run pre-commit

* Refactor logical_and and logical_or functions

* Add acknowledgement

* Add logical AND and logical OR operators

* Refactor logical_and and logical_or functions

* Add support for logical operators on bool arrays

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Add logical AND and OR operators for arrays and scalars

* Refactor vjp and jvp methods in primitives.cpp

* Add overloaded operators for logical AND and OR

* format

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-08 07:00:05 -08:00
Hazem Essam
022a944367
Added GLU activation function and Gated activation function (#329)
* Added GLU activation function and gated activation function

* Ran pre-commit

* Ran pre commit

* Removed old sigmoid implementation to match with main

* Removed gated activation from __init__.py

* Removed unused test cases

* Removed unused imports

* format / docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-08 06:13:16 -08:00
Angelos Katharopoulos
a611b0bc82
Removes the retain_graph flag (#385)
* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
2024-01-07 15:16:51 -08:00
Diogo
449b43762e
Add inner / outer op (#348)
* inner / outer impl

* python tests

* ops list and ack

* updated descriptions

* use test helper

* removed dtype check and flatten outer to 1-D

* updated docs

* just use the reshape to flatten
2024-01-07 09:01:09 -08:00
Angelos Katharopoulos
6ea6b4258d
Fix style check (#395) 2024-01-07 05:54:58 -08:00
Anchen
48f6ca8c3a
Add theta cache for Rope and mask cache for ALiBi (#375) 2024-01-07 00:22:58 -08:00
Awni Hannun
b34bf5d52b
fix saving for non-contiguous arrays (#389) 2024-01-06 12:44:02 -08:00
Angelos Katharopoulos
4c48f6460d
Fix segfault from buffer protocol and tests (#383)
* Fix segfault from buffer protocol and tests

* Fix tf test
2024-01-05 18:17:44 -08:00
Daniel Strobusch
1331fa19f6
Make array conform to the Python Buffer Protocol (#323) 2024-01-05 15:58:33 -08:00
Daniel Strobusch
dfdb284e16
make behaviour of dtype arguments consistent and compliant to numpy (#379)
All functions that take an optional dtype should

* have a default dtype visible in the generated docs (accomplished via `"dtype"_a = std::optional{float32}`)
* behave identical when `dtype=None` or no dtype is passed

This important when passing kw args down from a numpy function like:

```
def f(x, dtype=None):
  mx.random.uniform(dtype=dtype)
  # ...
```

NumPy functions behave like this.

It also fixes a minor bug in `tri`: #378

Closes #378
2024-01-05 09:37:46 -08:00
mutexuan
d8f41a5c0f
support python mlx.array creation from list of mlx.array's (#325)
* support python mlx.array creation from list of mlx.array's

* include bfloat16 in UT

* refactor so that sub array made of all python primitive types gets initialized by fill_vector

* address PR comment: arr.shape().size() -> arr.ndim()

* address PR comment: get back Dtype constness and let stack to handle type promotions automatically
2024-01-04 18:53:33 -08:00
Awni Hannun
b9e415d19c
bump pre commit and fix format (#373) 2024-01-04 16:28:52 -08:00
Angelos Katharopoulos
75dc537e44
Fix the sigmoid module (#371) 2024-01-04 13:16:36 -08:00
Awni Hannun
cf88db44b5
revert copy (#366) 2024-01-04 10:43:29 -08:00
Chunyang Wen
16856a0160
Remove useless pass (#364)
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
2024-01-04 06:34:01 -08:00
Awni Hannun
d752f8e142
Fix CI (#359)
* fix ci

* check for linux for fp16
2024-01-04 06:33:08 -08:00
toji
d2467c320d
Added support for python copy (#335)
* Added support for python copy

* precommit changes

* removed `_compiled_call_impl` line

* added tests and suggested changes

* ACK changes
2024-01-03 20:59:40 -08:00
Diogo
0d31128a44
use union instead of | (#358) 2024-01-03 19:33:19 -08:00
Diogo
1ac18eac20
simple numpy helper for tests (#352) 2024-01-03 19:19:19 -08:00
Angelos Katharopoulos
e7f5059fe4
Support for quantized matmul with w and w^T (#349)
* Add the metal qvm implementation
* Add qmm_n
* Add gradient wrt to input for quantized_matmul
2024-01-03 14:22:36 -08:00
Gabrijel Boduljak
c7edafb729
implemented InstanceNorm (#244)
* implemented instancenorm

* implemented vector_norm in cpp

added linalg to mlx

* implemented vector_norm python binding

* renamed vector_norm to norm, implemented norm without provided ord

* completed the implementation of the norm

* added tests

* removed unused import in linalg.cpp

* updated python bindings

* added some tests for python bindings

* handling inf, -inf as numpy does, more extensive tests of compatibility with numpy

* added better docs and examples

* refactored mlx.linalg.norm bindings

* reused existing util for implementation of linalg.norm

* more tests

* fixed a bug with no ord and axis provided

* removed unused imports

* some style and API consistency updates to linalg norm

* remove unused includes

* fix python tests

* fixed a bug with frobenius norm of a complex-valued matrix

* complex for vector too

* addressed PR review comments

* fixed import order in __init__

* expected values in instancenorm tests are simple lists

* minor return expression style change

* added InstanceNorm to docs

* doc string nits

* added myself to individual contributors

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-03 12:21:15 -08:00
Awni Hannun
dff4a3833f
Module checks the weight on load_weights (#337)
* update module to check weights on load, also fix docs and reorganize tests

* nits + rebase

* a few more docs updates for Module

* use manual module file

* comment
2024-01-02 18:55:42 -08:00
Diogo
0782a4573a
Add Tensordot op (#344) 2024-01-02 17:15:00 -08:00
Angelos Katharopoulos
436bec9fd9
Fix the implementation of the Bilinear layer (#347) 2024-01-02 16:46:18 -08:00
Asaf Zorea
295ce9db09
Feature expand nn linear (#315)
* Added an identity and bilinear layers
Added a reset_parameters option
Added normal init for bias

* pre-commit run

* add type hints for parameters and the return type
change Bilinear math to x_1 and x_2
change __call__ arguments to x and y instead of input and output
add explanation to the Initialization

* Remove unnecessary reshape

* Added 'i' to bilinear formula

* Changed bilinear computation to two matrix multiplications

* avoid saving intermediate results, kept y in bilinear for better clarity (can be replaced with x1)

* Changed math formula in Linear
Added more explanation to math formulas
Changed x1, x2 reshape to support all inputs sizes
2024-01-02 06:08:53 -08:00