Commit Graph

343 Commits

Author SHA1 Message Date
Jagrit Digani
8c2e15e6c8
Accelerate import updates for iOS (#1227)
* Update veclib and bnns includes to #include <Accelerate/Accelerate.h> for compatibility with ios

* Mark float literals in softmax.cpp to be float16_t for errors in ios

* Add arm neon vector operation guards

* Redirect to common backend for consistency
2024-06-26 09:01:50 -07:00
Awni Hannun
56c8a33439
Get metal version from xcode (#1228)
* get metal version from xcode

* typo

* fix
2024-06-26 07:02:11 -07:00
Jagrit Digani
2d6cd47713
Masked gemv (#1211) 2024-06-14 09:52:26 -07:00
Awni Hannun
fe3167d7ea
smaller CPU binary (#1203)
* smaller CPU binary

* fix no cpu build
2024-06-14 09:46:55 -07:00
Awni Hannun
31e134be35
Build for macOS 15 (#1208)
* Build for macos 15

* metal32 as well

* comment

---------

Co-authored-by: Awni Hannun <Awni Hannun>
2024-06-13 13:31:44 -07:00
Fangjun Kuang
f20e97b092
minor fixes (#1194)
* minor fixes

* fix build errors
2024-06-12 22:06:49 -07:00
Alex Barron
934683088e
Refactor JIT for unary/binary/ternary ops (#1206)
* refactor unary/binary/ternary ops

* get_primitive_string util

---------
2024-06-12 14:22:12 -07:00
Awni Hannun
de2b9e7d0a
Fix kernel deps to reduce build times (#1205) 2024-06-12 11:17:39 -07:00
Alex Barron
dd7d8e5e29
Add Quantized Ops to the JIT (#1204)
* JIT for quantized ops

* remove unused imports

* address comments

* fix imports

* second attempt to fix imports

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-12 09:47:12 -07:00
Awni Hannun
df964132fb
fix scatter + test (#1202)
* fix scatter + test

* fix test warnings

* fix metal validation
2024-06-11 14:35:12 -07:00
Alex Barron
27d70c7d9d
Feature complete Metal FFT (#1102)
* feature complete metal fft

* fix contiguity bug

* jit fft

* simplify rader/bluestein constant computation

* remove kernel/utils.h dep

* remove bf16.h dep

* format

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-06 12:57:25 -07:00
Awni Hannun
578842954c
fix jit scan when output doesn't have primitive (#1190) 2024-06-06 07:24:58 -07:00
Awni Hannun
496315fe1d
Fix scan (#1188)
* fix scan

* improve grid size

* fix cpu cummax
2024-06-05 14:21:58 -07:00
Awni Hannun
83b11bc58d
Fix Metal API validation for empty concat (#1183) 2024-06-04 13:17:08 -07:00
Alex Barron
375a8bbdcc
Add some internal GPU apis (#1177)
* Add unary/binary/ternay/slice/concat internal GPU ops

* add pad internal op

* formatting + no_cpu fix
2024-06-04 09:24:26 -07:00
Awni Hannun
ea9090bbc4
Add view op (#1179)
* add view primitive

* nit

* fix view
2024-06-04 08:05:27 -07:00
Alex Barron
4d485fca24
Add defines include (#1176)
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-03 09:50:10 -07:00
Brian Keene
1865299a30
Metal shaders for memory efficient self attention on large sequences (#964)
* Metal shaders for efficient self attention on large sequences

Updated fast attention: GEMM-ified with Steel primitives
Uses flash attention 1 for scale correction

* more compiler silencing

* Address rebase issues

* Templatize kernel instantiation, revise cpu bindings

* Safer writes to output

* Permit batch size > 1

* Numerical fixes for sdpa self attention

* Re-enable test, remove unused variable

* add benchmarking script

* Disable sdpa prior to perf tuning, and simplify tests for per-patch CI
2024-06-03 09:16:19 -07:00
Jagrit Digani
76b6cece46
Fix multi-block sort stride management (#1169)
* Fix multi-block sort stride management

* Add seed to tests
2024-05-31 11:10:54 -07:00
Jagrit Digani
9f0df51f8d
Fix matvec vector stride bug (#1168) 2024-05-29 12:18:28 -07:00
Awni Hannun
e7a2a3dcd1
Fix a couple bugs (#1161)
* fix jit reduce for RMS norm

* make strides a single buffer

* better eval error message

* fix compiling with inf and bf16

* fix cpu compile with bf16
2024-05-28 15:18:18 -07:00
Awni Hannun
7e26fd8032
Option to JIT steel gemm / conv (#1139) 2024-05-23 18:07:34 -07:00
Jagrit Digani
eab2685c67
Float mask update (#1152)
* Float mask update

* Update CPU impl
2024-05-23 17:20:44 -07:00
Awni Hannun
0189ab6ab6
More jitting (#1132)
* docs + circle min size build

* jit scan, arange, softmax

* add sort

* jit reductions

* remove print

* fix deps

* clean includes / nits
2024-05-23 16:23:44 -07:00
Rifur13
9401507336
Add groups to 2-D convolutions (#1129)
* Added groups to 2-D convolutions. Only implemented for **some** specializations.

Also fixed 1D grouped convs with different kernel strides and added more tests.

* fix channels condition
2024-05-22 20:01:44 -07:00
Awni Hannun
e110ca11e2
Fix offset bug for device buffers (#1151)
* fix bug with large offsets for buffers

* add a test

* remove test as its too big for small machine
2024-05-22 15:50:05 -07:00
Awni Hannun
226748b3e7
JIT compile option for binary minimization (#1091)
* try cpp 20 for compile

* unary, binary, ternary in jit

* nits

* fix gather/scatter

* fix rebase

* reorg compile

* add ternary to compile

* jit copy

* jit compile flag

* fix build

* use linked function for ternary

* some nits

* docs + circle min size build

* docs + circle min size build

* fix extension

* fix no cpu build

* improve includes
2024-05-22 12:57:13 -07:00
Awni Hannun
d568c7ee36
Rename block sparse (#1149)
* block_sparse_mm to gather_mm

* rename

* nit

* nit
2024-05-22 07:48:34 -07:00
Angelos Katharopoulos
da83f899bb
Improve qvm speed (#1140) 2024-05-20 09:20:44 -07:00
Awni Hannun
fb71a82ada
Fix copy bug with many dims (#1137) 2024-05-17 21:10:03 -07:00
Awni Hannun
23406c9e9e
Choose the right MLX bf16 for extensions (#1135)
* default to custom bf

* choose right bf

* fix extensions

* fix circle conf
2024-05-17 15:09:28 -07:00
Luca Arnaboldi
b3ec792380
Implemented Cholesky on CPU (#1119) 2024-05-17 12:31:59 -07:00
Angelos Katharopoulos
e78a6518fa
Block sparse qmm (#1124) 2024-05-16 15:24:14 -07:00
Awni Hannun
1873ffda01
Detect metal version and propagate correctly for JIT (#1109)
* detect metal version and propagate correctly for JIT

* remove softmax

* fix versions
2024-05-15 17:42:09 -07:00
Jagrit Digani
358e1fd6ab
Fused GEMM (#1123)
* Basic gemm working

* Update addmm

* Clear out steel_gemm and steel_addmm kernels

* Fuse and clear out gather gemm

* Update objc releases
2024-05-15 10:30:41 -07:00
Awni Hannun
863039da4c
Allow scatter type exception to be caught by checking in op (#1077)
* allow exception to be caught in main thread

* only for gpu

* more detailed scatter error
2024-05-13 17:43:53 -07:00
Awni Hannun
7178ac0111
No CPU option for binary minimization (#1105)
* no cpu build option

* docs

* fix
2024-05-13 16:08:11 -07:00
Max-Heinrich Laves
ff4223904d
Conv3d (#993)
* added conv3d

added conv3d

implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D

* incorporated reviewer comments

* fixed test

* reduced tensor shapes in test for conv3d

* Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion
2024-05-11 06:15:02 -07:00
Alex Barron
2e158cf6d0
Add conjugate operator (#1100)
* cpu and gpu impl

* add mx.conj and array.conj()

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-05-10 07:22:20 -07:00
Awni Hannun
06375e6605
Split encoders in non-concurrent context with a max ops per encoder (#1085)
* split encoders

* fix race
2024-05-09 16:21:02 -07:00
Rahul Yedida
cc05a281c4
Added ArcTan2 operation (#1079)
* Added ArcTan2 operation

* Cleanup, bug fixes from code review

* Minor cleanup, fixed Linux tests
2024-05-08 08:35:15 -07:00
Jagrit Digani
fe96ceee66
Update block offset adjustment to be in size_t (#1087) 2024-05-08 08:10:23 -07:00
Awni Hannun
21623156a3
Reset peak memory (#1074)
* reset peak memory

* fix linux

* nits in docs
2024-05-03 17:12:51 -07:00
Awni Hannun
b00ac960b4
change initial memory limits and add memory size to device info (#1064) 2024-05-03 06:50:15 -07:00
Jagrit Digani
f390957685
Block sparse mm (#1058) 2024-05-02 14:03:58 -07:00
Angelos Katharopoulos
17f57df797
Improvements in the quantizer and dequantization kernel (#1061) 2024-05-01 18:19:11 -07:00
Awni Hannun
7f7b9662ea
Fix leak for multi-output primitives which are never detached (#1059)
* fix multi output leak

* ignore arrays that will be detached

* add some comments

* stray print
2024-05-01 07:31:45 -07:00
Awni Hannun
19bef39f5c
Add a mx.metal.device_info (#1060)
* device inof

* add variant

* fix linux

* fix doc
2024-04-30 15:47:27 -07:00
Nripesh Niketan
a30e7ed2da
feat: metal formatting and pre-commit bump (#1038)
* feat: metal formatting and pre-commit bump

* add guards

* update

* more guards

* more guards

* smakk fix

* Refactor instantiation of ternary types in ternary.metal

* fix scan.metal
2024-04-30 07:18:09 -07:00
Awni Hannun
09f1777896
fix slice update indexing (#1053) 2024-04-29 12:17:40 -07:00
Rifur13
c4a471c99d
Add groups to Conv1d (#948)
* Add conv1d grouped convs on CPU

* Add GPU support

* Parallelize inside metal kernel

* clenaup

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* New unfold kernel + remove unused code

* Remove copy and refactor

* Update vjp and reuse steel gemm

* Fixed groups on cpu

* Fix metal validation

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-04-27 06:24:57 -07:00
Awni Hannun
86f495985b
Add bitwise ops (#1037)
* bitwise ops

* fix tests
2024-04-26 22:03:42 -07:00
Awni Hannun
5bfe89bdb1
Cpp docs (#1036)
* start of C++ docs

* fix stream doc

* only include ops for now
2024-04-26 12:56:05 -07:00
Awni Hannun
771575d27b
Expose function to clear memory cache (#1032)
* expose function to clear memory cache

* fix linux build

* fix metal tests
2024-04-24 16:48:51 -07:00
Angelos Katharopoulos
20a01bbd9f
Simplifying and improving qmm (#1030) 2024-04-24 13:07:45 -07:00
Awni Hannun
3d405fb3b1
Add synchronize function (#1006)
* add synchronize function

* fix linux

* fix linux

* fix and fix docs

* fix test

* try synchronize in stream destroy

* synchronize works for both cpu and gpu
2024-04-22 08:25:46 -07:00
Jagrit Digani
85c8a91a27
Fix mask broadcasting bug and add relevant test (#1003) 2024-04-17 17:33:48 -07:00
Awni Hannun
8a0677d56d
Shared events for synchronization + async eval (#998)
* more async eval

* fix rebase

* try correct async eval

* fix async

* more tests for async eval

* use shared events for synchronization

* comment + cleanup

* with autorelease pool

* fix no metal build

* fix compile

* fix patch

* don't eval if asyn evale'd

* don't use is_evaled

* comments

* more multi stream tests

* try and cleanup use of is_evaled

* use a status flag
2024-04-17 06:16:02 -07:00
Jagrit Digani
b18468bf81
Masked mm (#978)
* Add block masked matmul op and primitive
2024-04-16 14:45:39 -07:00
Alex Barron
2e7c02d5cd
Metal FFT for powers of 2 up to 2048 (#915)
* add Metal FFT for powers of 2

* skip GPU test on linux

* fix contiguity bug

* address comments

* Update mlx/backend/metal/fft.cpp

* Update mlx/backend/metal/fft.cpp

* fix bug in synch

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-11 21:40:06 -07:00
Awni Hannun
ae18326533
No copy command encoder (#986)
* no copy command encoder

* up layer norm test tolerances
2024-04-11 21:15:36 -07:00
Nripesh Niketan
ffff671273
Update pre-commit hooks (#984) 2024-04-11 07:27:53 -07:00
Awni Hannun
12d4507ee3
Explicit barriers with concurrent dispatch (#977) 2024-04-10 21:45:31 -07:00
Awni Hannun
42afe27e12
std and expm1 (#973)
* std and expm1

* actually add expm1

* fix linux

* fix vjp

* relax tol for linux test

* Add it to the compilable primitives

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-08 14:26:01 -07:00
Awni Hannun
76e63212ff
Enable bfloat scan (#974)
* enable bfloat scan
* fix tests
2024-04-08 12:29:19 -07:00
Awni Hannun
aac2f9fb61
Improve profiling with gpu tracing (#969)
* improve profiling with gpu tracing

* fix for linux

* nit

* doc fix

* fix example
2024-04-07 21:47:43 -07:00
Awni Hannun
d88d2124b5
segfaut layer norm grad (#955) 2024-04-04 10:59:15 -07:00
Awni Hannun
e142aaf8a1
Option for precise softmax (#953)
* precise softmax

* Add an equivalency check

* Make the threadgroup memory definition fixed

* precise cpu softmax

* precise option on cpu

* remove print

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-04 08:32:35 -07:00
Awni Hannun
2427fa171e
Fix cpu compile (#934)
* fix one cpu bug, test for another

* format hooks

* simplify contiguity check for cpu compile

* fix

* add back donation

* comment
2024-04-01 17:37:12 -07:00
Angelos Katharopoulos
110d9b149d
Layer norm grad fix donation bug (#941)
* add layer norm grad test

* Fix donation bug in layernorm vjp

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-01 06:15:50 -07:00
Angelos Katharopoulos
9cbff5ec1d
Fix typo in qmm check (#940) 2024-03-31 19:15:44 -07:00
Awni Hannun
8915901966
Donation bug (#933)
* donation

* buf

* fix bug in softmax

* comment

* remove print
2024-03-30 10:08:54 -07:00
Cheng
913b19329c
Add missing && when forwarding args (#925)
Without the && args would be copied and perfect forwarding won't work.
2024-03-29 06:48:29 -07:00
Angelos Katharopoulos
5f9ba3019f
Fix qmm_t for unaligned cases (#923) 2024-03-28 15:34:57 -07:00
Jack Mousseau
45f636e759
Add Metal debug option and capture functions (#707)
* Add Metal debug option and capture functions

* Add brief Metal debugger documentation

* doc nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-28 09:40:31 -07:00
Angelos Katharopoulos
aca7584635
Fix OOB read in qmv when non-divisible by blocksize (#917) 2024-03-27 22:18:35 -07:00
Angelos Katharopoulos
29221fa238
Implement vjps for some primitives in the fast namespace (#883)
* Implement rope vjp in terms of rope
* RMSNormVJP primitive and kernel
* Add LayerNormVJP primitive and kernel
2024-03-26 16:35:34 -07:00
Jagrit Digani
925014b661
Fix multiblock sort limits (#906)
* Fix multiblock sort limits

* Fix metal validation error
2024-03-26 14:00:00 -07:00
Angelos Katharopoulos
9948eddf11
Fix nan and improve speed for qvm (#903) 2024-03-26 10:41:45 -07:00
Cheng
28fcd2b519
Add missing && when forwarding args (#894)
Without the && args would be copied and perfect forwarding won't work.

Also add template utils to make sure the function only forwards array
and not vector<array>.
2024-03-25 14:55:54 -07:00
Daniel Strobusch
479051ce1c
add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate

numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)).

Closes #285.

* nits in docs

* unify type category checking

* nits in docs

* nits in docs

* more docs nits

* fix callable type

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-25 12:32:59 -07:00
Angelos Katharopoulos
6ee1112f30
Fix copy donation and add partial rope (#881) 2024-03-22 17:28:26 -07:00
Angelos Katharopoulos
2225374060
Adds mx.fast.layer_norm (#870) 2024-03-21 13:55:51 -07:00
nicolov
105d236889
Add vmap for SVD and inverse (#849) 2024-03-21 13:18:27 -07:00
Awni Hannun
a54f06b16f
Fast RMS Norm (#862)
* fast rmsnorm

* no rms gpu

* kernel

* fix shared mem

* looped rms and donation in softmax

* Make the squaring in float32 to avoid underflow

* Fix the default StreamOrDevice for rope and rms_norm in fast

* nits

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-21 07:20:54 -07:00
Jagrit Digani
b219d12a6b
Check edge case handling in row reduce med kernel (#858) 2024-03-20 11:37:58 -07:00
Jagrit Digani
cec8661113
Add a SliceUpdate op and primitive (#850)
* Enable copy to work with int64 strides
* Fix uniform buffer indices or copy kernel arguments
* Update utils.h
* Remove manual unrolling of elem to loc loop
* GPU copy updated to handle negative strides
* Add slice update primitive
2024-03-20 10:39:25 -07:00
Cheng
d39ed54f8e
Some C++ code are not needed (#841)
1. Anonymous namespace means internal linkage, static keyword is not needed.
2. The default constructor of std::shared_ptr initializes the pointer to
   nullptr, you don't need to explicitly set it.
2024-03-18 17:04:10 -07:00
Awni Hannun
16546c70d8
No reshape rope (#838)
* no reshape rope

* no reshape rope
2024-03-18 17:03:07 -07:00
nicolov
eaba55c9bf
Add matrix inversion primitive (#822) 2024-03-15 06:34:36 -07:00
Jagrit Digani
8dfc376c00
Strided reduce specialization for small reductions (#826)
* Add small column / general reduction specialization
2024-03-14 09:16:53 -07:00
Angelos Katharopoulos
1efee9db09
Add types and order in kernel name (#831) 2024-03-13 20:34:06 -07:00
Angelos Katharopoulos
3f8b1668c4
Make reshape faster for row_contiguous cases (#829) 2024-03-13 16:22:03 -07:00
Angelos Katharopoulos
76c919b4ec
NumberOfElements for shapeless compile and vmap fixes (#802) 2024-03-13 10:34:14 -07:00
Angelos Katharopoulos
29d0c10ee5
Reshape improvement (#818) 2024-03-12 17:54:31 -07:00
Jagrit Digani
5ad133f8bb
No copy gems (#801)
* Enable collapsing batch dims in gemm
* Update gemm to only make copies when neither of the last 2 axes are contiguous
* Update addmm to support gemv shapes
* Update addmm to support irregular batch strides
* Update tests
2024-03-12 13:13:41 -07:00
nicolov
d0c544a868
Add SVD primitive (#809)
Add SVD op using Accelerate's LAPACK following
https://developer.apple.com/documentation/accelerate/
compressing_an_image_using_linear_algebra

Co-authored-by: Nicolo Valigi <nvaligi@apple.com>
2024-03-12 12:30:11 -07:00
Awni Hannun
8b7532b9ab
fix scatter (#821) 2024-03-12 11:42:07 -07:00
nicolov
0ae22b915b
Remove code duplication in reduce ops (#793)
* Remove code duplication in reduce ops

* Remove the unnecessary lambda

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-11 10:57:07 -07:00
Awni Hannun
7c441600fe
Compile stride bug (#812)
* fix compile stride bug

* revert sdpa fix

* fix cpu

* fix bug with simplifying outputs
2024-03-11 06:31:31 -07:00
Jagrit Digani
ec8a4864fa
Fix SDPA kernel bug on Mac OS 13.3 SDK (#805)
* Move sdpa kernel to allocate tgp mem statically and allow macOS 13.3 SDK builds

* Style
2024-03-07 10:18:09 -08:00
Awni Hannun
f512b905c7
Minimum xcode / sdk (#800)
* minimum xcode /sdk

* try multiple xcode versions in CI

* update python

* metal validation for python tests
2024-03-07 08:19:43 -08:00
Angelos Katharopoulos
14b4e51a7c
Improved quantized matrix vector product (#786) 2024-03-05 17:32:19 -08:00
Awni Hannun
cbcf44a4ca
Some fixes in cache / thread safety (#777)
* some fixes in cache / thread safety

* speed up no cache case

* fix opt test

* optimizer docs

* otpimizer docs

* fix adafactor

* fix adafactor
2024-03-05 13:30:50 -08:00
Brian Keene
0787724c44
Fast Inference SDPA op (#735)
* Fast Inference SDPA op

Implements metal shaders for:

o = mx.fast_inference_sdpa(queries, keys, values, scale, mask)

Supports fp16, fp32 dtypes; assumes d_k = 128.

Generic op support / prompt encoding supported via mlx primitives.
Metal implementation is for the inference use case only.

Majority of performance benefits appears to results from GQA & reduced
bandwidth requirements; there is approximate performance parity for the
MHA use case (from some measurements on M3 Max).

* Flush shared memory to zero before unprotected reads for (scores @ values)

* Move to fast:: namespace, address reviewer comments

... also attempt to revert formatter auto-change for files not relevant
to this change

* Shared memory flush to top of kernel

* Resolve compiler warnings

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update docstring per PR feedback

* Softmax in higher precision, ...

* route to fallback for more use cases - batch size > 1, head_dim other
  than 128, etc.
* Address linux build failure
* Address other reviewer comments

* Remove extraneous eval_cpu function per review

---------

Co-authored-by: Atila Orhon <64497909+atiorh@users.noreply.github.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: atila <atiorh@icloud.com>
2024-03-04 21:06:11 -08:00
Awni Hannun
7b463ffb07
Ios compile (#784)
* try to fix build for ios

* skip cpu compile

* fix namespace

* fix namespace

* Use CMake for platform specific cpu compile

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-04 20:02:26 -08:00
Jagrit Digani
6686e61ca4
Reduce update (#783)
* Split reduction files to reduce compile times

* Add small and medium axis size specializations for row reductions

* Add non-row-reduction options for small and med kernels
2024-03-04 19:09:51 -08:00
Awni Hannun
d5964a2710
bindings for memory info (#761)
* bindings for memory info

* update api

* keep cache low if requested

* fix default

* nit in ops error
2024-03-01 19:51:58 -08:00
Jagrit Digani
776c3d226d
Convolution update (#651)
* Init steel conv and update Conv primitive

* Update slow CPU implementation to support flipping and input dilation winograd conv routing

Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-28 20:11:16 -08:00
Awni Hannun
f5f18b704f
fix temporary bug (#752) 2024-02-27 17:44:39 -08:00
Awni Hannun
56ba3ec40e
fix cpu compile on older OS (#747) 2024-02-26 22:20:53 -08:00
Awni Hannun
e6418781ab
Fix logsumexp edge case (#740)
* fix logsumexp

* fix inf constant

* also fix power grad

* fix ternary dispatch
2024-02-25 08:39:55 -08:00
Awni Hannun
ac02cf33bd
Fix some issues using MLX in C++ (#739)
* fix preamble build

* fix some issues with using MLX as a dep in C++
2024-02-24 22:20:57 -08:00
Rifur13
126c9869c8
Implement the 'where' primitive for conditional selection (#664) 2024-02-22 15:10:48 -08:00
Jagrit Digani
884b4ed43b
Fix threadgroup memory in arg reduce (#723) 2024-02-21 19:42:16 -08:00
Vijay Krish
972d9a3aea
Up to 10x faster scatter. (#709)
* Faster scatter.

Add specialization for 1-d index tensors.

* Address review comments.

- Check for row contiguity of index, update tensors
  instead of checking strides.
- Add support for 1d specialization with col contiguous update
  tensor, along with a test.

* Nit1

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Nit2

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-21 11:09:30 -08:00
Awni Hannun
5798256fcf
Shapeless compilation for some graphs (#687)
* shapeless compilation for some graphs

* update compile benchmark

* default compile a few activations

* buffer donation

* bugfix

* shapeless fix

* update tests to work for cpu and gpu fusion

* test kwargs

* add kwargs to compile

* Recompile when python arguments change

* no compile for tanh

* some constant tests

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-19 21:43:54 -08:00
Awni Hannun
1a4f4c5ea6
Refactor CPU compile preamble (#708)
* refactor cpu preamble

* fix include order

* fix some issues'

* fixes for linux

* try to fix includes

* add back warning suppression

* more linux fixes
2024-02-19 06:12:53 -08:00
Jack Mousseau
0925af43b0
Remove unused variables (#706) 2024-02-18 12:50:10 -08:00
Awni Hannun
dc937b8ed3
CPU compile (#691)
* build and load shared object for cpu compile

* nits

* cpu compile tests pass

* cpu compile tests pass

* fix preamble for g++

* donation

* fix gpu buffer donation

* reuse prebuilt libraries

* faster contiguity conditoins

* fix test

* rid compiler warning

* fast erf

* Fix float16 for compile and add more types to cpu compile

* Remove a forgotten comment

* use cached libs

* nits

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-17 06:54:32 -08:00
Awni Hannun
c3965fc5ee
Separate fast ops and primitives (#699) 2024-02-16 19:16:39 -08:00
Awni Hannun
ccf1645995
Custom primitive + RoPE fat op (#676)
* extensions start

* rope custom op

* fix build

* docs + rope benchmark

* fix test

* Add a Metal kernel for RoPE

* Fix position of traditional

* transform tests

* Move rope computation to float and fix tests

* Fix the test and a typo

* change to fast

* fix no metal build

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-14 14:04:25 -08:00
Jagrit Digani
1a48713d32
Update gather and scatter to not use Argument Encoder (#683)
* Replace argument encoder usage for gather and scatter

* Use constant address space for shapes and strides

* Split gather and scatter to improve compile times

* Enable the GPU tests

* Update the CI config

* Fix scatter dispatch for scalar indices

* Remove arg encoder utils

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-14 13:42:13 -08:00
Vijay Krish
2fdc2462c3
Faster gather and scatter. (#682)
Reduce unnecessary integer ops, especially since
there kernels are integer bound.

Increase number of iterations for benchmarks for
better smoothing.

Github Issue #506

Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
2024-02-13 17:47:41 -08:00
Angelos Katharopoulos
40c108766b
Quantized matmul fix (#677)
* Fix qmv for small or unaligned matrices

* Fix qmm
2024-02-12 18:54:21 -08:00
Awni Hannun
3756381358
Faster bfloat quantized mat-vec and vec-mat (#663) 2024-02-11 21:53:16 -08:00
Awni Hannun
d12573daa6
quote file name (#670) 2024-02-11 10:33:30 -08:00
Vijay Krish
06072601ce
Scatter optimization : Eliminate 64b integer divide. (#662)
Launch 2D grid to eliminate divide and mod in device code,
since 64b integer division is very expensive.

Github Issue #506

Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
2024-02-10 08:49:51 -08:00
Awni Hannun
7f3f8d8f8d
Fix the softmax fix (#661) 2024-02-09 17:02:13 -08:00
Awni Hannun
b96be943dc
bug fix (#658) 2024-02-09 16:50:45 -08:00
Abdussamet Türker
b670485185
Remainder negative numerator bug fixed (#641)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-09 16:49:14 -08:00
Angelos Katharopoulos
28eac18571
Kernel generation (#614)
Generate reusable element-wise kernels given a computation graph.
2024-02-07 13:15:59 -08:00
Jagrit Digani
316ff490b3
Remove masks from BlockLoader and clear out load case for invalid thread (#634) 2024-02-05 16:00:17 -08:00
Awni Hannun
d40a04f8dc
minor fixes (#631)
* minor fixes

* var with ddof >= nelements
2024-02-05 13:27:49 -08:00
Awni Hannun
d75ae52ecd
Compile primitive (#571)
* Compiled primitive with basic binary, unary graph-level fusion
2024-02-05 06:51:22 -08:00
Awni Hannun
e319383ef9
Faster gather (#626)
* faster gather

* update copyright
2024-02-04 17:25:44 -08:00
David Koski
ebfd3618b0
fixes for building and running on iOS (#619)
* fixes for building and running on iOS

* per suggestion just use Accelerate
2024-02-04 12:29:17 -08:00
Awni Hannun
cb6156d35d
Fix eval in trace bugs (#612)
* Fix eval in trace bugs

* comment nit
2024-02-02 09:57:12 -08:00
Vijay Krish
fcc5ac1c64
Add GPU support for uint64/int64 reductions (#569) 2024-01-31 11:18:04 -08:00
Angelos Katharopoulos
0de5988f92
Custom VJP and checkpointing (#541)
* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
2024-01-30 16:04:45 -08:00
Jagrit Digani
375446453e
Update Compute Pipeline Creation API (#581)
* Add option to specialize metal functions on function constants
* Update Compute Pipeline Creation API
* Add options to make libraries from source and stitching
* Update function specialization name options
2024-01-30 15:42:36 -08:00
Angelos Katharopoulos
1895d34c20
Fix log1p with inf inputs (#592) 2024-01-30 14:02:50 -08:00
Angelos Katharopoulos
65d0b8df9f
Fix binary op dispatch (#584) 2024-01-29 19:36:17 -08:00
Awni Hannun
3c2f192345
Propagate nans in binary ops (#579)
* propagate nans in binary ops

* handle empty matmul

* cpu minimum/maximum propagate nan

* benchmark maximum

* add min as well

* throw on negative indices with full

* verbose on linux

* fix matmul for zero K
2024-01-29 11:19:38 -08:00
Awni Hannun
8993382aaa
Buffer Donation (#519)
* buffer donation

* fix to move shared pointer

* format

* gpu in place for copy and binary

* revert ops test

* cpu in place

* a little cleanup

* remove useless bench
2024-01-26 16:30:33 -08:00
taher
077c1ee64a
QR factorization (#310)
* add qr factorization

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-26 09:27:31 -08:00
Jagrit Digani
6d3bee3364
Fix oob reads in gemv kernel (#523) 2024-01-22 12:06:04 -08:00
Awni Hannun
7a34e46677
Quantize with groups of 32 (#511)
* allow quantize with group sizes of 32

* missing cpu dispatch

* remove print

* Fix qvm for group_size 32

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-01-21 06:19:05 -08:00
Awni Hannun
3d99a8d31d
Fix format / build (#489) 2024-01-18 10:01:59 -08:00
Ethan
a749a91c75
Support disable metal buffer cache to prevent performance degradation caused by large memory caching (#390)
* support disable metal buffer cache, due to large unused memory buffered when llm generated long context tokens

* Run format and add "cache_enabled" feature tests
2024-01-18 08:33:34 -08:00
Angelos Katharopoulos
90c234b7ac
Fix round to round half-cases to even (#482) 2024-01-17 15:27:23 -08:00
Angelos Katharopoulos
135fd796d2
Fix detach for multi-output primitives (#480) 2024-01-17 14:08:07 -08:00
Jagrit Digani
78102a47ad
Update GEMM (#424)
* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/`
* Update gemm elements for better performance 
* Add split-K specialization for gemm
* Add `addmm` primitive, op and bindings for fused matmul and bias addition 
* Update tests and benchmarks as needed
2024-01-17 12:42:39 -08:00
Awni Hannun
275db7221a
Command buffer reports errors (#479)
* command buffer reports errors

* typo

* simplify
2024-01-17 11:53:30 -08:00
Angelos Katharopoulos
d8fabaa12b
Split multi output (#461)
* Multi-output split primitive
* Add the multi-output split to the ArrayIterator
* Add some grad tests for split
2024-01-16 13:33:55 -08:00
Awni Hannun
a2ffea683a
Fix eye for larger matrices (#463)
* fix eye
* fix scatter for <32bit (non native atomic) types
* fix int overflow
2024-01-16 00:51:24 -08:00
Angelos Katharopoulos
c15fe3e61b
Allow arbitrary first dimension in quantization kernels. (#458)
* Allow arbitrary first dim on qmm_t and qmv
* Allow arbitrary first dim on qmm and qvm
* Specialized aligned vs unaligned case
* Add more checks for valid quantizations
2024-01-16 00:46:21 -08:00
Awni Hannun
c9934fe8a4
Metal validation (#432)
* tests clear metal validation

* add cpp test with metal validation to circleci

* nit
2024-01-11 11:57:24 -08:00
Awni Hannun
f099ebe535
Multi output primitives (#330)
* Multi-output primitives

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-01-08 16:39:08 -08:00
Nripesh Niketan
73321b8097
feat: add logicalAnd and logicalOR (#386)
* feat: add logicalAnd and logicalOR

* run pre-commit

* Refactor logical_and and logical_or functions

* Add acknowledgement

* Add logical AND and logical OR operators

* Refactor logical_and and logical_or functions

* Add support for logical operators on bool arrays

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Add logical AND and OR operators for arrays and scalars

* Refactor vjp and jvp methods in primitives.cpp

* Add overloaded operators for logical AND and OR

* format

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-08 07:00:05 -08:00
Angelos Katharopoulos
a611b0bc82
Removes the retain_graph flag (#385)
* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
2024-01-07 15:16:51 -08:00
Awni Hannun
c6d2878c1a
safely divide for 0 size inputs (#388) 2024-01-07 00:19:54 -08:00
Awni Hannun
b9e415d19c
bump pre commit and fix format (#373) 2024-01-04 16:28:52 -08:00
davidkoski
c82a8cc526
move all ObjC (via metal-cpp) interaction until post static initializers (#370)
* move all ObjC (via metal-cpp) interaction until post static initializers

- metal-cpp relies on static initializers to cache class and selector pointers
- code in mlx was using metal-cpp to set up NSAutoreleasePools during its own static init time
- but this code was silently failing as the class and selector pointers from metal-cpp were still nil

- defer the creation of NSAutoreleasePools until after static init time
- ensure that we have coverage where autorelease pools are needed

* Update device.cpp

remove commented code

* Update device.cpp

remove commented out code

* Update scheduler.h

update comment

* per discussion use the pool inside the task() -- this will be metal only, not needed for cpu

* Update allocator.cpp

move pool to release/alloc area
2024-01-04 16:12:00 -08:00
Angelos Katharopoulos
e7f5059fe4
Support for quantized matmul with w and w^T (#349)
* Add the metal qvm implementation
* Add qmm_n
* Add gradient wrt to input for quantized_matmul
2024-01-03 14:22:36 -08:00
Awni Hannun
99c80a2c8b
Memory allocation (#292)
* try alternative gc

* try no cache

* add forced swap

* remove cache for now

* add cache back

* change fit crtieria

* remove unused function

* nit in comment

* tune / fix allocation

* increase block limit to original
2024-01-02 11:59:19 -08:00
Josh Soref
44c1ce5e6a
Spelling (#342)
* spelling: accumulates

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: across

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: additional

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: against

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: among

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: array

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: at least

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: available

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: axes

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: basically

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: bfloat

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: bounds

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: broadcast

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: buffer

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: class

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: coefficients

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: collision

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: combinations

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: committing

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: computation

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: consider

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: constructing

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: conversions

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: correctly

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: corresponding

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: declaration

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: default

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: dependency

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: destination

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: destructor

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: dimensions

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: divided

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: element-wise

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: elements

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: endianness

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: equivalent

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: explicitly

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: github

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: indices

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: irregularly

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: memory

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: metallib

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: negative

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: notable

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: optional

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: otherwise

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: overridden

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: partially

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: partition

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: perform

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: perturbations

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: positively

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: primitive

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: repeat

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: repeats

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: respect

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: respectively

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: result

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: rounding

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: separate

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: skipping

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: structure

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: the

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: transpose

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: unnecessary

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: unneeded

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: unsupported

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

---------

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
2024-01-01 21:08:17 -08:00
Diogo
1f6ab6a556
Safetensor support (#215)
Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-27 02:06:55 -08:00
Angelos Katharopoulos
9e6b8c9f48
Refactor the reduction kernels (#277) 2023-12-24 14:47:57 -08:00
Awni Hannun
8b227fa9af
fix no metal build (#276) 2023-12-23 19:18:10 -08:00
Ronan Collobert
cd3616a463
Revisit autorelease memory pools (#260)
* make general autorelease pool part of metal device

* make things simpler

* no metal backend support

* new_memory_pool -> new_scoped_memory_pool
2023-12-22 11:01:26 -08:00
Awni Hannun
2118c3dbfa
fix (#255) 2023-12-21 18:18:41 -08:00
Awni Hannun
a002797d52
A temporary fix (#254) 2023-12-21 17:59:15 -08:00
Daniel Strobusch
794feb83df
support arange for bfloat16 (#245) 2023-12-21 14:33:43 -08:00
Angelos Katharopoulos
b3916cbf2b
Improve names of quantization arguments (#235)
* Change the default quantization group_size to 64
* Rename groups to group_size and width to bits
2023-12-20 16:53:53 -08:00
Angelos Katharopoulos
2807c6aff0
Implements divide for integer types and adds floor_divide op (#228)
* Add floor_divide
* Add floor_divide to the tests
* Add floor_divide to the docs
2023-12-19 20:12:19 -08:00
davidkoski
de892cb66c
fix for non-macos build issue on cblas.h (#227) 2023-12-19 17:01:59 -08:00
davidkoski
37024d899c
fixes for building with swiftpm (#225)
- clbas is part of veclib (compile failure)
- add SWIFTPM_BUNDLE #define to allow loading the metallib from a swiftpm resource bundle
2023-12-19 16:22:10 -08:00
Angelos Katharopoulos
dfa9f4bc58
An initial quantized matmul implementation (#205)
* Add quantized matvec
* Add quantized matrix matrix with 2nd matrix transposed
* Add quantized matmul tests
* Add a slow cpu quantized matmul
* Add a slightly faster vectorized cpu version
2023-12-18 23:18:57 -08:00
Angelos Katharopoulos
4d4af12c6f
Adds round op and primitive (#203) 2023-12-18 11:32:48 -08:00
Ronan Collobert
83f266c44c
Lazy metal_device_ initialization (#185)
This ensures it is defined when the Scheduler needs it.
2023-12-15 16:06:46 -08:00
Luca Arnaboldi
b93c4cf378
Floor and Ceil (#150)
* Implements Floor and Ceil Ops
2023-12-14 10:00:23 -08:00
Ikko Eltociear Ashimine
c3272d4917
Update conv.cpp (#145)
Peform -> Perform
2023-12-12 11:27:49 -08:00
Awni Hannun
71d1fff90a
Bug fix in metal binary kernel dispatch for large arrays (#125)
* bug fix

* format
2023-12-10 16:12:31 -08:00
Angelos Katharopoulos
600db7d754
Fix build on Xcode 14 (#116)
* Fix build on Xcode 14

* Style fixes
2023-12-10 06:58:52 -08:00
Angelos Katharopoulos
2b714714e1
Add the remainder op (#85)
* Add remainder in the C++ backend
* Add the python binding and test
2023-12-08 15:08:52 -08:00
Angelos Katharopoulos
209404239b
Fix the accelerate dispatch for the power op (#70)
- The exponent and base were swapped because accelerate is using
  exponent-base instead of base-exponent
- Fix also the test for binary ops as it was testing op(x, x) which
  couldn't catch ordering errors like that
2023-12-08 10:58:03 -08:00
Jagrit Digani
d518b3b6a5
Fix gemv broadcasting bug (#6)
* Fix broadcasting bug in gemv
* Add relevant tests in test_blas.py
2023-12-05 14:15:43 -08:00
Awni Hannun
db487e6b1a format 2023-11-30 11:50:50 -08:00
Awni Hannun
46a39e5b1f copyright + ack 2023-11-30 11:12:53 -08:00
Jagrit Digani
e6306cfee9 jagrit's commit files 2023-11-29 10:52:08 -08:00
Angelos Katharopoulos
d1f86272a2 angelos's commit files 2023-11-29 10:42:59 -08:00
Awni Hannun
8ca7f9e8e9 awni's commit files 2023-11-29 10:30:41 -08:00