Jagrit Digani
8c2e15e6c8
Accelerate import updates for iOS ( #1227 )
...
* Update veclib and bnns includes to #include <Accelerate/Accelerate.h> for compatibility with ios
* Mark float literals in softmax.cpp to be float16_t for errors in ios
* Add arm neon vector operation guards
* Redirect to common backend for consistency
2024-06-26 09:01:50 -07:00
Awni Hannun
56c8a33439
Get metal version from xcode ( #1228 )
...
* get metal version from xcode
* typo
* fix
2024-06-26 07:02:11 -07:00
Jagrit Digani
2d6cd47713
Masked gemv ( #1211 )
2024-06-14 09:52:26 -07:00
Awni Hannun
fe3167d7ea
smaller CPU binary ( #1203 )
...
* smaller CPU binary
* fix no cpu build
2024-06-14 09:46:55 -07:00
Awni Hannun
31e134be35
Build for macOS 15 ( #1208 )
...
* Build for macos 15
* metal32 as well
* comment
---------
Co-authored-by: Awni Hannun <Awni Hannun>
2024-06-13 13:31:44 -07:00
Fangjun Kuang
f20e97b092
minor fixes ( #1194 )
...
* minor fixes
* fix build errors
2024-06-12 22:06:49 -07:00
Alex Barron
934683088e
Refactor JIT for unary/binary/ternary ops ( #1206 )
...
* refactor unary/binary/ternary ops
* get_primitive_string util
---------
2024-06-12 14:22:12 -07:00
Awni Hannun
de2b9e7d0a
Fix kernel deps to reduce build times ( #1205 )
2024-06-12 11:17:39 -07:00
Alex Barron
dd7d8e5e29
Add Quantized Ops to the JIT ( #1204 )
...
* JIT for quantized ops
* remove unused imports
* address comments
* fix imports
* second attempt to fix imports
---------
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-12 09:47:12 -07:00
Awni Hannun
df964132fb
fix scatter + test ( #1202 )
...
* fix scatter + test
* fix test warnings
* fix metal validation
2024-06-11 14:35:12 -07:00
Alex Barron
27d70c7d9d
Feature complete Metal FFT ( #1102 )
...
* feature complete metal fft
* fix contiguity bug
* jit fft
* simplify rader/bluestein constant computation
* remove kernel/utils.h dep
* remove bf16.h dep
* format
---------
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-06 12:57:25 -07:00
Awni Hannun
578842954c
fix jit scan when output doesn't have primitive ( #1190 )
2024-06-06 07:24:58 -07:00
Awni Hannun
496315fe1d
Fix scan ( #1188 )
...
* fix scan
* improve grid size
* fix cpu cummax
2024-06-05 14:21:58 -07:00
Awni Hannun
83b11bc58d
Fix Metal API validation for empty concat ( #1183 )
2024-06-04 13:17:08 -07:00
Alex Barron
375a8bbdcc
Add some internal GPU apis ( #1177 )
...
* Add unary/binary/ternay/slice/concat internal GPU ops
* add pad internal op
* formatting + no_cpu fix
2024-06-04 09:24:26 -07:00
Awni Hannun
ea9090bbc4
Add view op ( #1179 )
...
* add view primitive
* nit
* fix view
2024-06-04 08:05:27 -07:00
Alex Barron
4d485fca24
Add defines include ( #1176 )
...
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-03 09:50:10 -07:00
Brian Keene
1865299a30
Metal shaders for memory efficient self attention on large sequences ( #964 )
...
* Metal shaders for efficient self attention on large sequences
Updated fast attention: GEMM-ified with Steel primitives
Uses flash attention 1 for scale correction
* more compiler silencing
* Address rebase issues
* Templatize kernel instantiation, revise cpu bindings
* Safer writes to output
* Permit batch size > 1
* Numerical fixes for sdpa self attention
* Re-enable test, remove unused variable
* add benchmarking script
* Disable sdpa prior to perf tuning, and simplify tests for per-patch CI
2024-06-03 09:16:19 -07:00
Jagrit Digani
76b6cece46
Fix multi-block sort stride management ( #1169 )
...
* Fix multi-block sort stride management
* Add seed to tests
2024-05-31 11:10:54 -07:00
Jagrit Digani
9f0df51f8d
Fix matvec vector stride bug ( #1168 )
2024-05-29 12:18:28 -07:00
Awni Hannun
e7a2a3dcd1
Fix a couple bugs ( #1161 )
...
* fix jit reduce for RMS norm
* make strides a single buffer
* better eval error message
* fix compiling with inf and bf16
* fix cpu compile with bf16
2024-05-28 15:18:18 -07:00
Awni Hannun
7e26fd8032
Option to JIT steel gemm / conv ( #1139 )
2024-05-23 18:07:34 -07:00
Jagrit Digani
eab2685c67
Float mask update ( #1152 )
...
* Float mask update
* Update CPU impl
2024-05-23 17:20:44 -07:00
Awni Hannun
0189ab6ab6
More jitting ( #1132 )
...
* docs + circle min size build
* jit scan, arange, softmax
* add sort
* jit reductions
* remove print
* fix deps
* clean includes / nits
2024-05-23 16:23:44 -07:00
Rifur13
9401507336
Add groups to 2-D convolutions ( #1129 )
...
* Added groups to 2-D convolutions. Only implemented for **some** specializations.
Also fixed 1D grouped convs with different kernel strides and added more tests.
* fix channels condition
2024-05-22 20:01:44 -07:00
Awni Hannun
e110ca11e2
Fix offset bug for device buffers ( #1151 )
...
* fix bug with large offsets for buffers
* add a test
* remove test as its too big for small machine
2024-05-22 15:50:05 -07:00
Awni Hannun
226748b3e7
JIT compile option for binary minimization ( #1091 )
...
* try cpp 20 for compile
* unary, binary, ternary in jit
* nits
* fix gather/scatter
* fix rebase
* reorg compile
* add ternary to compile
* jit copy
* jit compile flag
* fix build
* use linked function for ternary
* some nits
* docs + circle min size build
* docs + circle min size build
* fix extension
* fix no cpu build
* improve includes
2024-05-22 12:57:13 -07:00
Awni Hannun
d568c7ee36
Rename block sparse ( #1149 )
...
* block_sparse_mm to gather_mm
* rename
* nit
* nit
2024-05-22 07:48:34 -07:00
Angelos Katharopoulos
da83f899bb
Improve qvm speed ( #1140 )
2024-05-20 09:20:44 -07:00
Awni Hannun
fb71a82ada
Fix copy bug with many dims ( #1137 )
2024-05-17 21:10:03 -07:00
Awni Hannun
23406c9e9e
Choose the right MLX bf16 for extensions ( #1135 )
...
* default to custom bf
* choose right bf
* fix extensions
* fix circle conf
2024-05-17 15:09:28 -07:00
Luca Arnaboldi
b3ec792380
Implemented Cholesky on CPU ( #1119 )
2024-05-17 12:31:59 -07:00
Angelos Katharopoulos
e78a6518fa
Block sparse qmm ( #1124 )
2024-05-16 15:24:14 -07:00
Awni Hannun
1873ffda01
Detect metal version and propagate correctly for JIT ( #1109 )
...
* detect metal version and propagate correctly for JIT
* remove softmax
* fix versions
2024-05-15 17:42:09 -07:00
Jagrit Digani
358e1fd6ab
Fused GEMM ( #1123 )
...
* Basic gemm working
* Update addmm
* Clear out steel_gemm and steel_addmm kernels
* Fuse and clear out gather gemm
* Update objc releases
2024-05-15 10:30:41 -07:00
Awni Hannun
863039da4c
Allow scatter type exception to be caught by checking in op ( #1077 )
...
* allow exception to be caught in main thread
* only for gpu
* more detailed scatter error
2024-05-13 17:43:53 -07:00
Awni Hannun
7178ac0111
No CPU option for binary minimization ( #1105 )
...
* no cpu build option
* docs
* fix
2024-05-13 16:08:11 -07:00
Max-Heinrich Laves
ff4223904d
Conv3d ( #993 )
...
* added conv3d
added conv3d
implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D
* incorporated reviewer comments
* fixed test
* reduced tensor shapes in test for conv3d
* Reviewer suggestion
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Reviewer suggestion
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Reviewer suggestion
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Reviewer suggestion
2024-05-11 06:15:02 -07:00
Alex Barron
2e158cf6d0
Add conjugate operator ( #1100 )
...
* cpu and gpu impl
* add mx.conj and array.conj()
---------
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-05-10 07:22:20 -07:00
Awni Hannun
06375e6605
Split encoders in non-concurrent context with a max ops per encoder ( #1085 )
...
* split encoders
* fix race
2024-05-09 16:21:02 -07:00
Rahul Yedida
cc05a281c4
Added ArcTan2 operation ( #1079 )
...
* Added ArcTan2 operation
* Cleanup, bug fixes from code review
* Minor cleanup, fixed Linux tests
2024-05-08 08:35:15 -07:00
Jagrit Digani
fe96ceee66
Update block offset adjustment to be in size_t ( #1087 )
2024-05-08 08:10:23 -07:00
Awni Hannun
21623156a3
Reset peak memory ( #1074 )
...
* reset peak memory
* fix linux
* nits in docs
2024-05-03 17:12:51 -07:00
Awni Hannun
b00ac960b4
change initial memory limits and add memory size to device info ( #1064 )
2024-05-03 06:50:15 -07:00
Jagrit Digani
f390957685
Block sparse mm ( #1058 )
2024-05-02 14:03:58 -07:00
Angelos Katharopoulos
17f57df797
Improvements in the quantizer and dequantization kernel ( #1061 )
2024-05-01 18:19:11 -07:00
Awni Hannun
7f7b9662ea
Fix leak for multi-output primitives which are never detached ( #1059 )
...
* fix multi output leak
* ignore arrays that will be detached
* add some comments
* stray print
2024-05-01 07:31:45 -07:00
Awni Hannun
19bef39f5c
Add a mx.metal.device_info
( #1060 )
...
* device inof
* add variant
* fix linux
* fix doc
2024-04-30 15:47:27 -07:00
Nripesh Niketan
a30e7ed2da
feat: metal formatting and pre-commit bump ( #1038 )
...
* feat: metal formatting and pre-commit bump
* add guards
* update
* more guards
* more guards
* smakk fix
* Refactor instantiation of ternary types in ternary.metal
* fix scan.metal
2024-04-30 07:18:09 -07:00
Awni Hannun
09f1777896
fix slice update indexing ( #1053 )
2024-04-29 12:17:40 -07:00
Rifur13
c4a471c99d
Add groups to Conv1d ( #948 )
...
* Add conv1d grouped convs on CPU
* Add GPU support
* Parallelize inside metal kernel
* clenaup
* Update mlx/ops.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* New unfold kernel + remove unused code
* Remove copy and refactor
* Update vjp and reuse steel gemm
* Fixed groups on cpu
* Fix metal validation
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-04-27 06:24:57 -07:00
Awni Hannun
86f495985b
Add bitwise ops ( #1037 )
...
* bitwise ops
* fix tests
2024-04-26 22:03:42 -07:00
Awni Hannun
5bfe89bdb1
Cpp docs ( #1036 )
...
* start of C++ docs
* fix stream doc
* only include ops for now
2024-04-26 12:56:05 -07:00
Awni Hannun
771575d27b
Expose function to clear memory cache ( #1032 )
...
* expose function to clear memory cache
* fix linux build
* fix metal tests
2024-04-24 16:48:51 -07:00
Angelos Katharopoulos
20a01bbd9f
Simplifying and improving qmm ( #1030 )
2024-04-24 13:07:45 -07:00
Awni Hannun
3d405fb3b1
Add synchronize function ( #1006 )
...
* add synchronize function
* fix linux
* fix linux
* fix and fix docs
* fix test
* try synchronize in stream destroy
* synchronize works for both cpu and gpu
2024-04-22 08:25:46 -07:00
Jagrit Digani
85c8a91a27
Fix mask broadcasting bug and add relevant test ( #1003 )
2024-04-17 17:33:48 -07:00
Awni Hannun
8a0677d56d
Shared events for synchronization + async eval ( #998 )
...
* more async eval
* fix rebase
* try correct async eval
* fix async
* more tests for async eval
* use shared events for synchronization
* comment + cleanup
* with autorelease pool
* fix no metal build
* fix compile
* fix patch
* don't eval if asyn evale'd
* don't use is_evaled
* comments
* more multi stream tests
* try and cleanup use of is_evaled
* use a status flag
2024-04-17 06:16:02 -07:00
Jagrit Digani
b18468bf81
Masked mm ( #978 )
...
* Add block masked matmul op and primitive
2024-04-16 14:45:39 -07:00
Alex Barron
2e7c02d5cd
Metal FFT for powers of 2 up to 2048 ( #915 )
...
* add Metal FFT for powers of 2
* skip GPU test on linux
* fix contiguity bug
* address comments
* Update mlx/backend/metal/fft.cpp
* Update mlx/backend/metal/fft.cpp
* fix bug in synch
---------
Co-authored-by: Alex Barron <abarron22@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-11 21:40:06 -07:00
Awni Hannun
ae18326533
No copy command encoder ( #986 )
...
* no copy command encoder
* up layer norm test tolerances
2024-04-11 21:15:36 -07:00
Nripesh Niketan
ffff671273
Update pre-commit hooks ( #984 )
2024-04-11 07:27:53 -07:00
Awni Hannun
12d4507ee3
Explicit barriers with concurrent dispatch ( #977 )
2024-04-10 21:45:31 -07:00
Awni Hannun
42afe27e12
std and expm1 ( #973 )
...
* std and expm1
* actually add expm1
* fix linux
* fix vjp
* relax tol for linux test
* Add it to the compilable primitives
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-08 14:26:01 -07:00
Awni Hannun
76e63212ff
Enable bfloat scan ( #974 )
...
* enable bfloat scan
* fix tests
2024-04-08 12:29:19 -07:00
Awni Hannun
aac2f9fb61
Improve profiling with gpu tracing ( #969 )
...
* improve profiling with gpu tracing
* fix for linux
* nit
* doc fix
* fix example
2024-04-07 21:47:43 -07:00
Awni Hannun
d88d2124b5
segfaut layer norm grad ( #955 )
2024-04-04 10:59:15 -07:00
Awni Hannun
e142aaf8a1
Option for precise softmax ( #953 )
...
* precise softmax
* Add an equivalency check
* Make the threadgroup memory definition fixed
* precise cpu softmax
* precise option on cpu
* remove print
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-04 08:32:35 -07:00
Awni Hannun
2427fa171e
Fix cpu compile ( #934 )
...
* fix one cpu bug, test for another
* format hooks
* simplify contiguity check for cpu compile
* fix
* add back donation
* comment
2024-04-01 17:37:12 -07:00
Angelos Katharopoulos
110d9b149d
Layer norm grad fix donation bug ( #941 )
...
* add layer norm grad test
* Fix donation bug in layernorm vjp
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-01 06:15:50 -07:00
Angelos Katharopoulos
9cbff5ec1d
Fix typo in qmm check ( #940 )
2024-03-31 19:15:44 -07:00
Awni Hannun
8915901966
Donation bug ( #933 )
...
* donation
* buf
* fix bug in softmax
* comment
* remove print
2024-03-30 10:08:54 -07:00
Cheng
913b19329c
Add missing && when forwarding args ( #925 )
...
Without the && args would be copied and perfect forwarding won't work.
2024-03-29 06:48:29 -07:00
Angelos Katharopoulos
5f9ba3019f
Fix qmm_t for unaligned cases ( #923 )
2024-03-28 15:34:57 -07:00
Jack Mousseau
45f636e759
Add Metal debug option and capture functions ( #707 )
...
* Add Metal debug option and capture functions
* Add brief Metal debugger documentation
* doc nits
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-28 09:40:31 -07:00
Angelos Katharopoulos
aca7584635
Fix OOB read in qmv when non-divisible by blocksize ( #917 )
2024-03-27 22:18:35 -07:00
Angelos Katharopoulos
29221fa238
Implement vjps for some primitives in the fast namespace ( #883 )
...
* Implement rope vjp in terms of rope
* RMSNormVJP primitive and kernel
* Add LayerNormVJP primitive and kernel
2024-03-26 16:35:34 -07:00
Jagrit Digani
925014b661
Fix multiblock sort limits ( #906 )
...
* Fix multiblock sort limits
* Fix metal validation error
2024-03-26 14:00:00 -07:00
Angelos Katharopoulos
9948eddf11
Fix nan and improve speed for qvm ( #903 )
2024-03-26 10:41:45 -07:00
Cheng
28fcd2b519
Add missing && when forwarding args ( #894 )
...
Without the && args would be copied and perfect forwarding won't work.
Also add template utils to make sure the function only forwards array
and not vector<array>.
2024-03-25 14:55:54 -07:00
Daniel Strobusch
479051ce1c
add numeric type hierarchy and issubdtype as well as a set_dtype meth… ( #427 )
...
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate
numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)
).
Closes #285 .
* nits in docs
* unify type category checking
* nits in docs
* nits in docs
* more docs nits
* fix callable type
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-25 12:32:59 -07:00
Angelos Katharopoulos
6ee1112f30
Fix copy donation and add partial rope ( #881 )
2024-03-22 17:28:26 -07:00
Angelos Katharopoulos
2225374060
Adds mx.fast.layer_norm ( #870 )
2024-03-21 13:55:51 -07:00
nicolov
105d236889
Add vmap for SVD and inverse ( #849 )
2024-03-21 13:18:27 -07:00
Awni Hannun
a54f06b16f
Fast RMS Norm ( #862 )
...
* fast rmsnorm
* no rms gpu
* kernel
* fix shared mem
* looped rms and donation in softmax
* Make the squaring in float32 to avoid underflow
* Fix the default StreamOrDevice for rope and rms_norm in fast
* nits
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-21 07:20:54 -07:00
Jagrit Digani
b219d12a6b
Check edge case handling in row reduce med kernel ( #858 )
2024-03-20 11:37:58 -07:00
Jagrit Digani
cec8661113
Add a SliceUpdate op and primitive ( #850 )
...
* Enable copy to work with int64 strides
* Fix uniform buffer indices or copy kernel arguments
* Update utils.h
* Remove manual unrolling of elem to loc loop
* GPU copy updated to handle negative strides
* Add slice update primitive
2024-03-20 10:39:25 -07:00
Cheng
d39ed54f8e
Some C++ code are not needed ( #841 )
...
1. Anonymous namespace means internal linkage, static keyword is not needed.
2. The default constructor of std::shared_ptr initializes the pointer to
nullptr, you don't need to explicitly set it.
2024-03-18 17:04:10 -07:00
Awni Hannun
16546c70d8
No reshape rope ( #838 )
...
* no reshape rope
* no reshape rope
2024-03-18 17:03:07 -07:00
nicolov
eaba55c9bf
Add matrix inversion primitive ( #822 )
2024-03-15 06:34:36 -07:00
Jagrit Digani
8dfc376c00
Strided reduce specialization for small reductions ( #826 )
...
* Add small column / general reduction specialization
2024-03-14 09:16:53 -07:00
Angelos Katharopoulos
1efee9db09
Add types and order in kernel name ( #831 )
2024-03-13 20:34:06 -07:00
Angelos Katharopoulos
3f8b1668c4
Make reshape faster for row_contiguous cases ( #829 )
2024-03-13 16:22:03 -07:00
Angelos Katharopoulos
76c919b4ec
NumberOfElements for shapeless compile and vmap fixes ( #802 )
2024-03-13 10:34:14 -07:00
Angelos Katharopoulos
29d0c10ee5
Reshape improvement ( #818 )
2024-03-12 17:54:31 -07:00
Jagrit Digani
5ad133f8bb
No copy gems ( #801 )
...
* Enable collapsing batch dims in gemm
* Update gemm to only make copies when neither of the last 2 axes are contiguous
* Update addmm to support gemv shapes
* Update addmm to support irregular batch strides
* Update tests
2024-03-12 13:13:41 -07:00
nicolov
d0c544a868
Add SVD primitive ( #809 )
...
Add SVD op using Accelerate's LAPACK following
https://developer.apple.com/documentation/accelerate/
compressing_an_image_using_linear_algebra
Co-authored-by: Nicolo Valigi <nvaligi@apple.com>
2024-03-12 12:30:11 -07:00
Awni Hannun
8b7532b9ab
fix scatter ( #821 )
2024-03-12 11:42:07 -07:00
nicolov
0ae22b915b
Remove code duplication in reduce ops ( #793 )
...
* Remove code duplication in reduce ops
* Remove the unnecessary lambda
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-11 10:57:07 -07:00
Awni Hannun
7c441600fe
Compile stride bug ( #812 )
...
* fix compile stride bug
* revert sdpa fix
* fix cpu
* fix bug with simplifying outputs
2024-03-11 06:31:31 -07:00
Jagrit Digani
ec8a4864fa
Fix SDPA kernel bug on Mac OS 13.3 SDK ( #805 )
...
* Move sdpa kernel to allocate tgp mem statically and allow macOS 13.3 SDK builds
* Style
2024-03-07 10:18:09 -08:00
Awni Hannun
f512b905c7
Minimum xcode / sdk ( #800 )
...
* minimum xcode /sdk
* try multiple xcode versions in CI
* update python
* metal validation for python tests
2024-03-07 08:19:43 -08:00
Angelos Katharopoulos
14b4e51a7c
Improved quantized matrix vector product ( #786 )
2024-03-05 17:32:19 -08:00
Awni Hannun
cbcf44a4ca
Some fixes in cache / thread safety ( #777 )
...
* some fixes in cache / thread safety
* speed up no cache case
* fix opt test
* optimizer docs
* otpimizer docs
* fix adafactor
* fix adafactor
2024-03-05 13:30:50 -08:00
Brian Keene
0787724c44
Fast Inference SDPA op ( #735 )
...
* Fast Inference SDPA op
Implements metal shaders for:
o = mx.fast_inference_sdpa(queries, keys, values, scale, mask)
Supports fp16, fp32 dtypes; assumes d_k = 128.
Generic op support / prompt encoding supported via mlx primitives.
Metal implementation is for the inference use case only.
Majority of performance benefits appears to results from GQA & reduced
bandwidth requirements; there is approximate performance parity for the
MHA use case (from some measurements on M3 Max).
* Flush shared memory to zero before unprotected reads for (scores @ values)
* Move to fast:: namespace, address reviewer comments
... also attempt to revert formatter auto-change for files not relevant
to this change
* Shared memory flush to top of kernel
* Resolve compiler warnings
* Update python/src/fast.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/fast.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/fast.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/fast.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update docstring per PR feedback
* Softmax in higher precision, ...
* route to fallback for more use cases - batch size > 1, head_dim other
than 128, etc.
* Address linux build failure
* Address other reviewer comments
* Remove extraneous eval_cpu function per review
---------
Co-authored-by: Atila Orhon <64497909+atiorh@users.noreply.github.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: atila <atiorh@icloud.com>
2024-03-04 21:06:11 -08:00
Awni Hannun
7b463ffb07
Ios compile ( #784 )
...
* try to fix build for ios
* skip cpu compile
* fix namespace
* fix namespace
* Use CMake for platform specific cpu compile
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-04 20:02:26 -08:00
Jagrit Digani
6686e61ca4
Reduce update ( #783 )
...
* Split reduction files to reduce compile times
* Add small and medium axis size specializations for row reductions
* Add non-row-reduction options for small and med kernels
2024-03-04 19:09:51 -08:00
Awni Hannun
d5964a2710
bindings for memory info ( #761 )
...
* bindings for memory info
* update api
* keep cache low if requested
* fix default
* nit in ops error
2024-03-01 19:51:58 -08:00
Jagrit Digani
776c3d226d
Convolution update ( #651 )
...
* Init steel conv and update Conv primitive
* Update slow CPU implementation to support flipping and input dilation winograd conv routing
Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-28 20:11:16 -08:00
Awni Hannun
f5f18b704f
fix temporary bug ( #752 )
2024-02-27 17:44:39 -08:00
Awni Hannun
56ba3ec40e
fix cpu compile on older OS ( #747 )
2024-02-26 22:20:53 -08:00
Awni Hannun
e6418781ab
Fix logsumexp edge case ( #740 )
...
* fix logsumexp
* fix inf constant
* also fix power grad
* fix ternary dispatch
2024-02-25 08:39:55 -08:00
Awni Hannun
ac02cf33bd
Fix some issues using MLX in C++ ( #739 )
...
* fix preamble build
* fix some issues with using MLX as a dep in C++
2024-02-24 22:20:57 -08:00
Rifur13
126c9869c8
Implement the 'where' primitive for conditional selection ( #664 )
2024-02-22 15:10:48 -08:00
Jagrit Digani
884b4ed43b
Fix threadgroup memory in arg reduce ( #723 )
2024-02-21 19:42:16 -08:00
Vijay Krish
972d9a3aea
Up to 10x faster scatter. ( #709 )
...
* Faster scatter.
Add specialization for 1-d index tensors.
* Address review comments.
- Check for row contiguity of index, update tensors
instead of checking strides.
- Add support for 1d specialization with col contiguous update
tensor, along with a test.
* Nit1
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Nit2
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-21 11:09:30 -08:00
Awni Hannun
5798256fcf
Shapeless compilation for some graphs ( #687 )
...
* shapeless compilation for some graphs
* update compile benchmark
* default compile a few activations
* buffer donation
* bugfix
* shapeless fix
* update tests to work for cpu and gpu fusion
* test kwargs
* add kwargs to compile
* Recompile when python arguments change
* no compile for tanh
* some constant tests
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-19 21:43:54 -08:00
Awni Hannun
1a4f4c5ea6
Refactor CPU compile preamble ( #708 )
...
* refactor cpu preamble
* fix include order
* fix some issues'
* fixes for linux
* try to fix includes
* add back warning suppression
* more linux fixes
2024-02-19 06:12:53 -08:00
Jack Mousseau
0925af43b0
Remove unused variables ( #706 )
2024-02-18 12:50:10 -08:00
Awni Hannun
dc937b8ed3
CPU compile ( #691 )
...
* build and load shared object for cpu compile
* nits
* cpu compile tests pass
* cpu compile tests pass
* fix preamble for g++
* donation
* fix gpu buffer donation
* reuse prebuilt libraries
* faster contiguity conditoins
* fix test
* rid compiler warning
* fast erf
* Fix float16 for compile and add more types to cpu compile
* Remove a forgotten comment
* use cached libs
* nits
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-17 06:54:32 -08:00
Awni Hannun
c3965fc5ee
Separate fast ops and primitives ( #699 )
2024-02-16 19:16:39 -08:00
Awni Hannun
ccf1645995
Custom primitive + RoPE fat op ( #676 )
...
* extensions start
* rope custom op
* fix build
* docs + rope benchmark
* fix test
* Add a Metal kernel for RoPE
* Fix position of traditional
* transform tests
* Move rope computation to float and fix tests
* Fix the test and a typo
* change to fast
* fix no metal build
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-14 14:04:25 -08:00
Jagrit Digani
1a48713d32
Update gather and scatter to not use Argument Encoder ( #683 )
...
* Replace argument encoder usage for gather and scatter
* Use constant address space for shapes and strides
* Split gather and scatter to improve compile times
* Enable the GPU tests
* Update the CI config
* Fix scatter dispatch for scalar indices
* Remove arg encoder utils
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-14 13:42:13 -08:00
Vijay Krish
2fdc2462c3
Faster gather and scatter. ( #682 )
...
Reduce unnecessary integer ops, especially since
there kernels are integer bound.
Increase number of iterations for benchmarks for
better smoothing.
Github Issue #506
Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
2024-02-13 17:47:41 -08:00
Angelos Katharopoulos
40c108766b
Quantized matmul fix ( #677 )
...
* Fix qmv for small or unaligned matrices
* Fix qmm
2024-02-12 18:54:21 -08:00
Awni Hannun
3756381358
Faster bfloat quantized mat-vec and vec-mat ( #663 )
2024-02-11 21:53:16 -08:00
Awni Hannun
d12573daa6
quote file name ( #670 )
2024-02-11 10:33:30 -08:00
Vijay Krish
06072601ce
Scatter optimization : Eliminate 64b integer divide. ( #662 )
...
Launch 2D grid to eliminate divide and mod in device code,
since 64b integer division is very expensive.
Github Issue #506
Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
2024-02-10 08:49:51 -08:00
Awni Hannun
7f3f8d8f8d
Fix the softmax fix ( #661 )
2024-02-09 17:02:13 -08:00
Awni Hannun
b96be943dc
bug fix ( #658 )
2024-02-09 16:50:45 -08:00
Abdussamet Türker
b670485185
Remainder negative numerator bug fixed ( #641 )
...
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-09 16:49:14 -08:00
Angelos Katharopoulos
28eac18571
Kernel generation ( #614 )
...
Generate reusable element-wise kernels given a computation graph.
2024-02-07 13:15:59 -08:00
Jagrit Digani
316ff490b3
Remove masks from BlockLoader and clear out load case for invalid thread ( #634 )
2024-02-05 16:00:17 -08:00
Awni Hannun
d40a04f8dc
minor fixes ( #631 )
...
* minor fixes
* var with ddof >= nelements
2024-02-05 13:27:49 -08:00
Awni Hannun
d75ae52ecd
Compile primitive ( #571 )
...
* Compiled primitive with basic binary, unary graph-level fusion
2024-02-05 06:51:22 -08:00
Awni Hannun
e319383ef9
Faster gather ( #626 )
...
* faster gather
* update copyright
2024-02-04 17:25:44 -08:00
David Koski
ebfd3618b0
fixes for building and running on iOS ( #619 )
...
* fixes for building and running on iOS
* per suggestion just use Accelerate
2024-02-04 12:29:17 -08:00
Awni Hannun
cb6156d35d
Fix eval in trace bugs ( #612 )
...
* Fix eval in trace bugs
* comment nit
2024-02-02 09:57:12 -08:00
Vijay Krish
fcc5ac1c64
Add GPU support for uint64/int64 reductions ( #569 )
2024-01-31 11:18:04 -08:00
Angelos Katharopoulos
0de5988f92
Custom VJP and checkpointing ( #541 )
...
* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
2024-01-30 16:04:45 -08:00
Jagrit Digani
375446453e
Update Compute Pipeline Creation API ( #581 )
...
* Add option to specialize metal functions on function constants
* Update Compute Pipeline Creation API
* Add options to make libraries from source and stitching
* Update function specialization name options
2024-01-30 15:42:36 -08:00
Angelos Katharopoulos
1895d34c20
Fix log1p with inf inputs ( #592 )
2024-01-30 14:02:50 -08:00
Angelos Katharopoulos
65d0b8df9f
Fix binary op dispatch ( #584 )
2024-01-29 19:36:17 -08:00
Awni Hannun
3c2f192345
Propagate nans in binary ops ( #579 )
...
* propagate nans in binary ops
* handle empty matmul
* cpu minimum/maximum propagate nan
* benchmark maximum
* add min as well
* throw on negative indices with full
* verbose on linux
* fix matmul for zero K
2024-01-29 11:19:38 -08:00
Awni Hannun
8993382aaa
Buffer Donation ( #519 )
...
* buffer donation
* fix to move shared pointer
* format
* gpu in place for copy and binary
* revert ops test
* cpu in place
* a little cleanup
* remove useless bench
2024-01-26 16:30:33 -08:00
taher
077c1ee64a
QR factorization ( #310 )
...
* add qr factorization
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-26 09:27:31 -08:00
Jagrit Digani
6d3bee3364
Fix oob reads in gemv kernel ( #523 )
2024-01-22 12:06:04 -08:00
Awni Hannun
7a34e46677
Quantize with groups of 32 ( #511 )
...
* allow quantize with group sizes of 32
* missing cpu dispatch
* remove print
* Fix qvm for group_size 32
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-01-21 06:19:05 -08:00
Awni Hannun
3d99a8d31d
Fix format / build ( #489 )
2024-01-18 10:01:59 -08:00
Ethan
a749a91c75
Support disable metal buffer cache to prevent performance degradation caused by large memory caching ( #390 )
...
* support disable metal buffer cache, due to large unused memory buffered when llm generated long context tokens
* Run format and add "cache_enabled" feature tests
2024-01-18 08:33:34 -08:00
Angelos Katharopoulos
90c234b7ac
Fix round to round half-cases to even ( #482 )
2024-01-17 15:27:23 -08:00
Angelos Katharopoulos
135fd796d2
Fix detach for multi-output primitives ( #480 )
2024-01-17 14:08:07 -08:00
Jagrit Digani
78102a47ad
Update GEMM ( #424 )
...
* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/`
* Update gemm elements for better performance
* Add split-K specialization for gemm
* Add `addmm` primitive, op and bindings for fused matmul and bias addition
* Update tests and benchmarks as needed
2024-01-17 12:42:39 -08:00
Awni Hannun
275db7221a
Command buffer reports errors ( #479 )
...
* command buffer reports errors
* typo
* simplify
2024-01-17 11:53:30 -08:00
Angelos Katharopoulos
d8fabaa12b
Split multi output ( #461 )
...
* Multi-output split primitive
* Add the multi-output split to the ArrayIterator
* Add some grad tests for split
2024-01-16 13:33:55 -08:00
Awni Hannun
a2ffea683a
Fix eye for larger matrices ( #463 )
...
* fix eye
* fix scatter for <32bit (non native atomic) types
* fix int overflow
2024-01-16 00:51:24 -08:00
Angelos Katharopoulos
c15fe3e61b
Allow arbitrary first dimension in quantization kernels. ( #458 )
...
* Allow arbitrary first dim on qmm_t and qmv
* Allow arbitrary first dim on qmm and qvm
* Specialized aligned vs unaligned case
* Add more checks for valid quantizations
2024-01-16 00:46:21 -08:00
Awni Hannun
c9934fe8a4
Metal validation ( #432 )
...
* tests clear metal validation
* add cpp test with metal validation to circleci
* nit
2024-01-11 11:57:24 -08:00
Awni Hannun
f099ebe535
Multi output primitives ( #330 )
...
* Multi-output primitives
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-01-08 16:39:08 -08:00
Nripesh Niketan
73321b8097
feat: add logicalAnd and logicalOR ( #386 )
...
* feat: add logicalAnd and logicalOR
* run pre-commit
* Refactor logical_and and logical_or functions
* Add acknowledgement
* Add logical AND and logical OR operators
* Refactor logical_and and logical_or functions
* Add support for logical operators on bool arrays
* Update mlx/ops.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update mlx/ops.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Add logical AND and OR operators for arrays and scalars
* Refactor vjp and jvp methods in primitives.cpp
* Add overloaded operators for logical AND and OR
* format
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-08 07:00:05 -08:00
Angelos Katharopoulos
a611b0bc82
Removes the retain_graph
flag ( #385 )
...
* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
2024-01-07 15:16:51 -08:00
Awni Hannun
c6d2878c1a
safely divide for 0 size inputs ( #388 )
2024-01-07 00:19:54 -08:00
Awni Hannun
b9e415d19c
bump pre commit and fix format ( #373 )
2024-01-04 16:28:52 -08:00
davidkoski
c82a8cc526
move all ObjC (via metal-cpp) interaction until post static initializers ( #370 )
...
* move all ObjC (via metal-cpp) interaction until post static initializers
- metal-cpp relies on static initializers to cache class and selector pointers
- code in mlx was using metal-cpp to set up NSAutoreleasePools during its own static init time
- but this code was silently failing as the class and selector pointers from metal-cpp were still nil
- defer the creation of NSAutoreleasePools until after static init time
- ensure that we have coverage where autorelease pools are needed
* Update device.cpp
remove commented code
* Update device.cpp
remove commented out code
* Update scheduler.h
update comment
* per discussion use the pool inside the task() -- this will be metal only, not needed for cpu
* Update allocator.cpp
move pool to release/alloc area
2024-01-04 16:12:00 -08:00
Angelos Katharopoulos
e7f5059fe4
Support for quantized matmul with w and w^T ( #349 )
...
* Add the metal qvm implementation
* Add qmm_n
* Add gradient wrt to input for quantized_matmul
2024-01-03 14:22:36 -08:00
Awni Hannun
99c80a2c8b
Memory allocation ( #292 )
...
* try alternative gc
* try no cache
* add forced swap
* remove cache for now
* add cache back
* change fit crtieria
* remove unused function
* nit in comment
* tune / fix allocation
* increase block limit to original
2024-01-02 11:59:19 -08:00
Josh Soref
44c1ce5e6a
Spelling ( #342 )
...
* spelling: accumulates
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: across
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: additional
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: against
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: among
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: array
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: at least
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: available
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: axes
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: basically
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: bfloat
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: bounds
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: broadcast
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: buffer
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: class
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: coefficients
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: collision
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: combinations
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: committing
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: computation
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: consider
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: constructing
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: conversions
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: correctly
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: corresponding
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: declaration
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: default
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: dependency
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: destination
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: destructor
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: dimensions
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: divided
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: element-wise
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: elements
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: endianness
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: equivalent
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: explicitly
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: github
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: indices
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: irregularly
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: memory
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: metallib
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: negative
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: notable
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: optional
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: otherwise
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: overridden
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: partially
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: partition
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: perform
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: perturbations
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: positively
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: primitive
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: repeat
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: repeats
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: respect
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: respectively
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: result
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: rounding
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: separate
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: skipping
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: structure
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: the
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: transpose
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: unnecessary
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: unneeded
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: unsupported
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
---------
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
2024-01-01 21:08:17 -08:00
Diogo
1f6ab6a556
Safetensor support ( #215 )
...
Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-27 02:06:55 -08:00
Angelos Katharopoulos
9e6b8c9f48
Refactor the reduction kernels ( #277 )
2023-12-24 14:47:57 -08:00
Awni Hannun
8b227fa9af
fix no metal build ( #276 )
2023-12-23 19:18:10 -08:00
Ronan Collobert
cd3616a463
Revisit autorelease memory pools ( #260 )
...
* make general autorelease pool part of metal device
* make things simpler
* no metal backend support
* new_memory_pool -> new_scoped_memory_pool
2023-12-22 11:01:26 -08:00
Awni Hannun
2118c3dbfa
fix ( #255 )
2023-12-21 18:18:41 -08:00
Awni Hannun
a002797d52
A temporary fix ( #254 )
2023-12-21 17:59:15 -08:00
Daniel Strobusch
794feb83df
support arange for bfloat16 ( #245 )
2023-12-21 14:33:43 -08:00
Angelos Katharopoulos
b3916cbf2b
Improve names of quantization arguments ( #235 )
...
* Change the default quantization group_size to 64
* Rename groups to group_size and width to bits
2023-12-20 16:53:53 -08:00
Angelos Katharopoulos
2807c6aff0
Implements divide for integer types and adds floor_divide op ( #228 )
...
* Add floor_divide
* Add floor_divide to the tests
* Add floor_divide to the docs
2023-12-19 20:12:19 -08:00
davidkoski
de892cb66c
fix for non-macos build issue on cblas.h ( #227 )
2023-12-19 17:01:59 -08:00
davidkoski
37024d899c
fixes for building with swiftpm ( #225 )
...
- clbas is part of veclib (compile failure)
- add SWIFTPM_BUNDLE #define to allow loading the metallib from a swiftpm resource bundle
2023-12-19 16:22:10 -08:00
Angelos Katharopoulos
dfa9f4bc58
An initial quantized matmul implementation ( #205 )
...
* Add quantized matvec
* Add quantized matrix matrix with 2nd matrix transposed
* Add quantized matmul tests
* Add a slow cpu quantized matmul
* Add a slightly faster vectorized cpu version
2023-12-18 23:18:57 -08:00
Angelos Katharopoulos
4d4af12c6f
Adds round op and primitive ( #203 )
2023-12-18 11:32:48 -08:00
Ronan Collobert
83f266c44c
Lazy metal_device_ initialization ( #185 )
...
This ensures it is defined when the Scheduler needs it.
2023-12-15 16:06:46 -08:00
Luca Arnaboldi
b93c4cf378
Floor and Ceil ( #150 )
...
* Implements Floor and Ceil Ops
2023-12-14 10:00:23 -08:00
Ikko Eltociear Ashimine
c3272d4917
Update conv.cpp ( #145 )
...
Peform -> Perform
2023-12-12 11:27:49 -08:00
Awni Hannun
71d1fff90a
Bug fix in metal binary kernel dispatch for large arrays ( #125 )
...
* bug fix
* format
2023-12-10 16:12:31 -08:00
Angelos Katharopoulos
600db7d754
Fix build on Xcode 14 ( #116 )
...
* Fix build on Xcode 14
* Style fixes
2023-12-10 06:58:52 -08:00
Angelos Katharopoulos
2b714714e1
Add the remainder op ( #85 )
...
* Add remainder in the C++ backend
* Add the python binding and test
2023-12-08 15:08:52 -08:00
Angelos Katharopoulos
209404239b
Fix the accelerate dispatch for the power op ( #70 )
...
- The exponent and base were swapped because accelerate is using
exponent-base instead of base-exponent
- Fix also the test for binary ops as it was testing op(x, x) which
couldn't catch ordering errors like that
2023-12-08 10:58:03 -08:00
Jagrit Digani
d518b3b6a5
Fix gemv broadcasting bug ( #6 )
...
* Fix broadcasting bug in gemv
* Add relevant tests in test_blas.py
2023-12-05 14:15:43 -08:00
Awni Hannun
db487e6b1a
format
2023-11-30 11:50:50 -08:00
Awni Hannun
46a39e5b1f
copyright + ack
2023-11-30 11:12:53 -08:00
Jagrit Digani
e6306cfee9
jagrit's commit files
2023-11-29 10:52:08 -08:00
Angelos Katharopoulos
d1f86272a2
angelos's commit files
2023-11-29 10:42:59 -08:00
Awni Hannun
8ca7f9e8e9
awni's commit files
2023-11-29 10:30:41 -08:00