Commit Graph

93 Commits

Author SHA1 Message Date
Awni Hannun
48ef3e74e2
reduce vjp for all and any (#2193) 2025-05-16 08:38:49 -07:00
Awni Hannun
602f43e3d1
fix conv grad (#2187) 2025-05-15 19:20:36 -07:00
Awni Hannun
c1eb9d05d9
non-symmetric eig and eigh (#2188) 2025-05-15 13:01:44 -07:00
Angelos Katharopoulos
cf6c939e86
Fix some complex vjps (#2178) 2025-05-14 23:37:12 -07:00
ATurker
a7fae8a176
fix: conv_general differences between gpu, cpu (#2070)
* fix general_conv padding

* fix bugs

* add test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-05-09 10:26:52 -07:00
Awni Hannun
7bb063bcb3
Enable vjp for quantized scale and bias (#2129)
* Enable vjp for quantized scale and bias

* higher tol
2025-04-29 13:03:09 -07:00
Awni Hannun
79b527f45f
conv vmap (#2102) 2025-04-21 13:04:39 -07:00
Angelos Katharopoulos
5de6d94a90
Gather qmm batched kernel and refactoring of quantized (#2078) 2025-04-17 13:53:11 -07:00
Angelos Katharopoulos
99eefd2ec0
Gather mm new kernel and small refactoring (#2040) 2025-04-14 16:37:36 -07:00
Yury Popov
e9e268336b
LogCumSumExp (#2069) 2025-04-13 01:27:29 -07:00
Awni Hannun
de5f38fd48
Custom logsumexp (#2028)
* initial custom logsumexp

* more tests

* comments + fix
2025-03-31 07:36:55 -07:00
Awni Hannun
32da94507a
fix vmap for flatten (#1955) 2025-03-11 10:42:22 -07:00
Abe Leininger
3835a428c5
Adds nuclear norm support (#1894)
* adjust norm unit test tolerance
2025-03-04 13:26:02 -08:00
Alex Barron
5cd97f7ffe
Bitwise Inverse (#1862)
* add bitwise inverse

* add vmap + fix nojit

* inverse -> invert

* add to compile + remove unused
2025-02-13 08:44:14 -08:00
Angelos Katharopoulos
9eb7d7362f
Fix Split::vmap (#1845) 2025-02-08 09:22:13 -08:00
Awni Hannun
af1b725fda
Fix a couple of slicing bugs (#1827)
* fix a few bugs

* fix conv grad

* speedup test

* comment
2025-02-05 19:50:08 -08:00
Awni Hannun
b7c9f1d38f
scatter axis + gather axis primitives (#1813)
* scatter axis + gather axis primitives

* add transforms

* comment
2025-01-31 20:48:08 -08:00
Awni Hannun
0c259961ac
matmul jvps (#1772) 2025-01-17 10:36:26 -08:00
Awni Hannun
1ccaf80575
Dynamic broadcasting for shapeless compile/export (#1722)
* working towards dynamic broadcast

* shapeless broadcast

* fix build + nits

* use broadcast arrays in quantize matmul

* some cleanup / consistency

* mend

* some comments

* add vjp, jvp for broadcast axes
2025-01-09 11:04:24 -08:00
Awni Hannun
516ded618b
Dynamic slicing (#1741)
* dynamic slice and slice update

* python bindings + tests + fix set item

* fix compile issue

* comment

* fix jit
2025-01-07 14:02:16 -08:00
Awni Hannun
ae69cb15e9
shapeless compile in docs and partially shapeless reshape (#1742) 2025-01-02 16:24:42 -08:00
Awni Hannun
4ba0c24a8f
Export / import functions to / from a file (#1642)
* export and import functions

* refactor + works for few primitives

* nit

* allow primitives with state

* nit

* nit

* simplify serialize / deserialize

* fix for constants

* python bindings

* maybe fix serialize failure case

* add example

* more primitives, training kind of works

* same result for python and c++

* some fixes

* fix export

* template it up

* some simplificatoin

* rebase

* allow kwargs and multiple functions

* exporter

* more primitives for exporting

* deal with endianness

* handle invalid stream

* add docstring
2024-12-24 11:19:13 -08:00
Awni Hannun
ebfe64b92d
shapeless slice update and broadcast when possible (#1727) 2024-12-23 11:25:15 -08:00
Awni Hannun
e03f0372b1
More shape type (#1705)
* more shape type

* fix
2024-12-19 08:08:20 -08:00
Awni Hannun
d03c01dfbc
fix unflatten vjp (#1708) 2024-12-16 18:37:57 -08:00
Awni Hannun
4e1e9520e1
Flatten and unflatten (#1692)
* flatten and unflatten

* fix grad

* fix shape infer

* use squeeze + unsqueeze in get_item
2024-12-11 21:51:37 -08:00
Awni Hannun
f76a49e555
ExpandDims primitive (#1687)
* add squeeze primitive

* simplify squeeze, use in gather

* fix

* fix

* fix

* fix

* fix no cpu

* use squeeze in matmul and friends

* expand dims primitive

* comment
2024-12-10 16:39:07 -08:00
Awni Hannun
40c62c1321
Use int64 stride everywhere (#1671)
* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
2024-12-09 11:09:02 -08:00
Cheng
d0f471cff7
Using math defines requires switch in MSVC (#1665)
* Using math defines requires switch in MSVC

* Fix more math macros

* Fix type

* Remove _MSC_VER guard for math defines
2024-12-08 08:16:28 -08:00
Awni Hannun
d0b6cb0425
More primitives for compiling with shapeless (#1653)
* more shapeless and more Shape

* more shape

* fix

* fix
2024-12-06 11:29:18 -08:00
Awni Hannun
dcca0d7477
contiguous op / prim (#1612) 2024-11-21 19:51:49 -08:00
Angelos Katharopoulos
5e89aace9b
Fix concatenate vmap (#1600) 2024-11-19 10:44:04 -08:00
Awni Hannun
54f05e7195
Fix gather vmap (#1563)
* fix gather

* fix
2024-11-05 11:29:20 -08:00
Kashif Rasul
3ddc07e936
Eigenvalues and eigenvectors (#1334)
* initial eigvalsh

* add compute_vectors

* add compute_vectors_

* return a pair

* add eigh to return only eigenvectors

* fixed typo

* merge merge Eighvalsh and Eigh into a single primitive

* use the same primate with the flag

* fix primatives

* use MULTI

* fix eval_gpu

* fix decleration

* rename EighPrimitive to Eigh

* tests

* tests

* fix rebase and format

* cleanup lapack

* format

* add cblas.h

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-10-22 12:18:48 -07:00
Awni Hannun
3f86399922
Real and Imag (#1490)
* real and imag

* fix

* fix
2024-10-15 16:23:15 -07:00
Awni Hannun
e4534dac17
Conv grad with groups + bugfix (#1449)
* fix bug in flipped conv with groups, start of grad for groups

* fix

* fix

* fix + test
2024-10-06 07:08:53 -07:00
Awni Hannun
195b429d99
Put along axis + fixe for partition grad (#1430)
* put along axis, fixes for partition grad

* zeros for arg reduce
2024-09-23 10:03:38 -07:00
Max-Heinrich Laves
efeb9c0f02
Transposed Convolution (#1245)
* initial implementation for conv_transpose

ran pre-commit

implemented conv_transpose

updated conv_general docstring

updated conv_general docstring

updated code comments

removed commented run_conv_checks

updated acknowledgments

added missing entry to ops.rst

added op to nn.layers

resolved merge conflicts

* removed ConvolutionTranspose primitive as suggested by reviewer

removed ConvolutionTranspose primitive as suggested by reviewer

* remove transpose flag, add another test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-06 19:52:38 -07:00
Alex Barron
635ccd9e25
Add "edge" mode to mx.pad (#1309)
* Add edge padding mode

* fix pad in pooling

* string arg instead of enum
2024-08-06 11:23:10 -07:00
nicolov
8c9f0278b9
Add vmap to scatter (#1200)
* Add vmap to scatter

* updates

* vmap updates + a few more tests

* bug fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-08-05 20:12:27 -07:00
Feng Shijie
987785d8d7
Fix typo and missing header (#1266) 2024-07-15 08:20:24 -07:00
Angelos Katharopoulos
5c1fa64fb0
Custom transforms (#1246) 2024-07-10 18:00:01 -07:00
Alex Barron
a3c287354f
Fast Hadamard Transform (#1249)
* Working hadamard for powers of 2

* working for m*2^k

* add scale and check contiguity

* add size check

* clean up

* fix test

* add grads + vmap

* gpu only

* skip on linux

* test typo

* add cpu impl

* remove gpu only tests

* fix linux build + add is_equivalent
2024-07-09 20:39:01 -07:00
Alex Barron
bdb36c9a63
add zero vjps for bitwise ops and gather w.r.t. index (#1256) 2024-07-07 21:34:59 -07:00
Jagrit Digani
2d6cd47713
Masked gemv (#1211) 2024-06-14 09:52:26 -07:00
Fangjun Kuang
f20e97b092
minor fixes (#1194)
* minor fixes

* fix build errors
2024-06-12 22:06:49 -07:00
Awni Hannun
ea9090bbc4
Add view op (#1179)
* add view primitive

* nit

* fix view
2024-06-04 08:05:27 -07:00
Awni Hannun
fd1c08137b
stable cumprod grad at 0 (#1167) 2024-05-31 12:28:42 -07:00
Jagrit Digani
eab2685c67
Float mask update (#1152)
* Float mask update

* Update CPU impl
2024-05-23 17:20:44 -07:00
Awni Hannun
226748b3e7
JIT compile option for binary minimization (#1091)
* try cpp 20 for compile

* unary, binary, ternary in jit

* nits

* fix gather/scatter

* fix rebase

* reorg compile

* add ternary to compile

* jit copy

* jit compile flag

* fix build

* use linked function for ternary

* some nits

* docs + circle min size build

* docs + circle min size build

* fix extension

* fix no cpu build

* improve includes
2024-05-22 12:57:13 -07:00