MLX
Loading...
Searching...
No Matches
primitives.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <unordered_set>
6
7#include "mlx/array.h"
8#include "mlx/device.h"
9#include "mlx/io/load.h"
10#include "mlx/stream.h"
11
// Declares the `vmap` override inside the body of a Primitive subclass; the
// subclass must supply the definition out of line.
// `override` already implies the function is virtual, so the redundant
// `virtual` keyword is omitted.
#define DEFINE_VMAP()                                                 \
  std::pair<std::vector<array>, std::vector<int>> vmap(               \
      const std::vector<array>& inputs, const std::vector<int>& axes) \
      override;
16
// Declares the `jvp` and `vjp` overrides inside the body of a Primitive
// subclass. Both are declarations only; the subclass defines them out of line.
// The signatures mirror the virtual jvp/vjp declared on Primitive.
17#define DEFINE_GRADS() \
18 std::vector<array> jvp( \
19 const std::vector<array>& primals, \
20 const std::vector<array>& tangents, \
21 const std::vector<int>& argnums) override; \
22 \
23 std::vector<array> vjp( \
24 const std::vector<array>& primals, \
25 const std::vector<array>& cotangents, \
26 const std::vector<int>& argnums, \
27 const std::vector<array>& outputs) override;
28
// Defines an inline `print` override that writes the macro argument,
// stringified with `#PRIMITIVE`, to the given stream.
29#define DEFINE_PRINT(PRIMITIVE) \
30 void print(std::ostream& os) override { \
31 os << #PRIMITIVE; \
32 }
33
// Defines an inline `is_equivalent` override that always returns true.
// The parameter is never read, so it is left unnamed to avoid an
// unused-parameter warning in every class that expands this macro.
#define DEFINE_DEFAULT_IS_EQUIVALENT()                              \
  bool is_equivalent(const Primitive& /* other */) const override { \
    return true;                                                    \
  }
38
// Defines an inline `output_shapes` override that reports exactly one output
// shape, copied from the first input. `inputs[0]` is indexed without a size
// check, so the expansion relies on callers passing at least one input.
39#define DEFINE_INPUT_OUTPUT_SHAPE() \
40 std::vector<std::vector<int>> output_shapes( \
41 const std::vector<array>& inputs) override { \
42 return {inputs[0].shape()}; \
43 }
44
45namespace mlx::core {
46
47// Abstract base class
48class Primitive {
49 public:
50 explicit Primitive(Stream stream) : stream_(stream) {}
51
53 const Device& device() {
54 return stream().device;
55 }
56
58 const Stream& stream() {
59 return stream_;
60 }
61
69 virtual void eval_cpu(
70 const std::vector<array>& inputs,
71 std::vector<array>& outputs) = 0;
72 virtual void eval_gpu(
73 const std::vector<array>& inputs,
74 std::vector<array>& outputs) = 0;
75
79 virtual std::vector<array> jvp(
80 const std::vector<array>& primals,
81 const std::vector<array>& tangents,
82 const std::vector<int>& argnums);
83
87 virtual std::vector<array> vjp(
88 const std::vector<array>& primals,
89 const std::vector<array>& cotangents,
90 const std::vector<int>& argnums,
91 const std::vector<array>& outputs);
92
99 virtual std::pair<std::vector<array>, std::vector<int>> vmap(
100 const std::vector<array>& inputs,
101 const std::vector<int>& axes);
102
104 virtual void print(std::ostream& os) = 0;
105
107 virtual bool is_equivalent(const Primitive& other) const {
108 return false;
109 }
110
113 virtual std::vector<std::vector<int>> output_shapes(
114 const std::vector<array>& inputs);
115
116 virtual ~Primitive() = default;
117 Primitive(const Primitive& other) = delete;
118 Primitive(Primitive&& other) = delete;
119 Primitive& operator=(const Primitive& other) = delete;
120 Primitive& operator=(Primitive&& other) = delete;
121
122 private:
123 // Every primitive stores the stream it should run in
124 Stream stream_;
125};
126
// Base class for primitives whose evaluation writes a single output array.
// Subclasses implement the single-output eval_cpu/eval_gpu overloads; the
// vector-based overrides inherited from Primitive forward to them using
// outputs[0].
// NOTE(review): embedded line numbers 128-130, 132, 150 and 152 are missing
// from this listing, so the constructor and the move operations are not
// visible here — confirm against the original header.
127class UnaryPrimitive : public Primitive {
131 public:
133
134 virtual void eval_cpu(const std::vector<array>& inputs, array& output) = 0;
135 virtual void eval_gpu(const std::vector<array>& inputs, array& output) = 0;
136
// Adapters from the multi-output interface to the single-output overloads.
// outputs[0] is accessed without a size check.
137 inline void eval_cpu(
138 const std::vector<array>& inputs,
139 std::vector<array>& outputs) override {
140 eval_cpu(inputs, outputs[0]);
141 }
142 inline void eval_gpu(
143 const std::vector<array>& inputs,
144 std::vector<array>& outputs) override {
145 eval_gpu(inputs, outputs[0]);
146 }
147
148 virtual ~UnaryPrimitive() = default;
149 UnaryPrimitive(const UnaryPrimitive& other) = delete;
151 UnaryPrimitive& operator=(const UnaryPrimitive& other) = delete;
153};
154
// Abs: single-output primitive (subclass of UnaryPrimitive).
// Declares the CPU/GPU evaluation overrides and a private eval() helper; all
// three are defined outside this header.
// NOTE(review): embedded line numbers 157 and 162-166 are missing from this
// listing; members declared there (constructor? DEFINE_* uses?) are not
// shown — confirm against the original header.
155class Abs : public UnaryPrimitive {
156 public:
158
159 void eval_cpu(const std::vector<array>& inputs, array& out) override;
160 void eval_gpu(const std::vector<array>& inputs, array& out) override;
161
167
168 private:
169 void eval(const std::vector<array>& inputs, array& out);
170};
171
// Add: single-output primitive (subclass of UnaryPrimitive).
// Declares the CPU/GPU evaluation overrides and a private eval() helper; all
// three are defined outside this header.
// NOTE(review): embedded line numbers 174 and 179-183 are missing from this
// listing; members declared there are not shown — confirm against the
// original header.
172class Add : public UnaryPrimitive {
173 public:
175
176 void eval_cpu(const std::vector<array>& inputs, array& out) override;
177 void eval_gpu(const std::vector<array>& inputs, array& out) override;
178
184
185 private:
186 void eval(const std::vector<array>& inputs, array& out);
187};
188
// AddMM: single-output primitive parameterised by two float coefficients.
// The constructor forwards `stream` to UnaryPrimitive and stores alpha and
// beta, which are immutable afterwards (const members).
// Overrides vjp and is_equivalent; both are defined out of line.
// NOTE(review): embedded line numbers 203-204 are missing from this listing —
// confirm against the original header.
189class AddMM : public UnaryPrimitive {
190 public:
191 explicit AddMM(Stream stream, float alpha, float beta)
192 : UnaryPrimitive(stream), alpha_(alpha), beta_(beta) {}
193
194 void eval_cpu(const std::vector<array>& inputs, array& out) override;
195 void eval_gpu(const std::vector<array>& inputs, array& out) override;
196
197 std::vector<array> vjp(
198 const std::vector<array>& primals,
199 const std::vector<array>& cotangents,
200 const std::vector<int>& argnums,
201 const std::vector<array>& outputs) override;
202
205
206 bool is_equivalent(const Primitive& other) const override;
207
208 private:
209 const float alpha_;
210 const float beta_;
211};
212
// Arange: single-output primitive carrying a start/stop/step triple.
// The constructor forwards `stream` to UnaryPrimitive and stores the three
// doubles. is_equivalent is overridden and defined out of line.
// NOTE(review): embedded line number 221 is missing from this listing —
// confirm against the original header.
213class Arange : public UnaryPrimitive {
214 public:
215 explicit Arange(Stream stream, double start, double stop, double step)
216 : UnaryPrimitive(stream), start_(start), stop_(stop), step_(step) {}
217
218 void eval_cpu(const std::vector<array>& inputs, array& out) override;
219 void eval_gpu(const std::vector<array>& inputs, array& out) override;
220
222 bool is_equivalent(const Primitive& other) const override;
223
224 private:
225 double start_;
226 double stop_;
227 double step_;
228
229 void eval(const std::vector<array>& inputs, array& out);
230};
231
// ArcCos: single-output primitive (subclass of UnaryPrimitive).
// Declares the CPU/GPU evaluation overrides and a private eval() helper, all
// defined outside this header.
// NOTE(review): embedded line numbers 234 and 239-243 are missing from this
// listing — confirm against the original header.
232class ArcCos : public UnaryPrimitive {
233 public:
235
236 void eval_cpu(const std::vector<array>& inputs, array& out) override;
237 void eval_gpu(const std::vector<array>& inputs, array& out) override;
238
244
245 private:
246 void eval(const std::vector<array>& inputs, array& out);
247};
248
// ArcCosh: single-output primitive (subclass of UnaryPrimitive).
// Declares the CPU/GPU evaluation overrides and a private eval() helper, all
// defined outside this header.
// NOTE(review): embedded line numbers 251 and 256-260 are missing from this
// listing — confirm against the original header.
249class ArcCosh : public UnaryPrimitive {
250 public:
252
253 void eval_cpu(const std::vector<array>& inputs, array& out) override;
254 void eval_gpu(const std::vector<array>& inputs, array& out) override;
255
261
262 private:
263 void eval(const std::vector<array>& inputs, array& out);
264};
265
// ArcSin: single-output primitive (subclass of UnaryPrimitive).
// Declares the CPU/GPU evaluation overrides and a private eval() helper, all
// defined outside this header.
// NOTE(review): embedded line numbers 268 and 273-277 are missing from this
// listing — confirm against the original header.
266class ArcSin : public UnaryPrimitive {
267 public:
269
270 void eval_cpu(const std::vector<array>& inputs, array& out) override;
271 void eval_gpu(const std::vector<array>& inputs, array& out) override;
272
278
279 private:
280 void eval(const std::vector<array>& inputs, array& out);
281};
282
// ArcSinh: single-output primitive (subclass of UnaryPrimitive).
// Declares the CPU/GPU evaluation overrides and a private eval() helper, all
// defined outside this header.
// NOTE(review): embedded line numbers 285 and 290-294 are missing from this
// listing — confirm against the original header.
283class ArcSinh : public UnaryPrimitive {
284 public:
286
287 void eval_cpu(const std::vector<array>& inputs, array& out) override;
288 void eval_gpu(const std::vector<array>& inputs, array& out) override;
289
295
296 private:
297 void eval(const std::vector<array>& inputs, array& out);
298};
299
// ArcTan: single-output primitive (subclass of UnaryPrimitive).
// Declares the CPU/GPU evaluation overrides and a private eval() helper, all
// defined outside this header.
// NOTE(review): embedded line numbers 302 and 307-311 are missing from this
// listing — confirm against the original header.
300class ArcTan : public UnaryPrimitive {
301 public:
303
304 void eval_cpu(const std::vector<array>& inputs, array& out) override;
305 void eval_gpu(const std::vector<array>& inputs, array& out) override;
306
312
313 private:
314 void eval(const std::vector<array>& inputs, array& out);
315};
316
// ArcTan2: single-output primitive (subclass of UnaryPrimitive).
// Declares the CPU/GPU evaluation overrides and a private eval() helper, all
// defined outside this header.
// NOTE(review): embedded line numbers 319 and 324-328 are missing from this
// listing — confirm against the original header.
317class ArcTan2 : public UnaryPrimitive {
318 public:
320
321 void eval_cpu(const std::vector<array>& inputs, array& out) override;
322 void eval_gpu(const std::vector<array>& inputs, array& out) override;
323
329
330 private:
331 void eval(const std::vector<array>& inputs, array& out);
332};
333
// ArcTanh: single-output primitive (subclass of UnaryPrimitive).
// Declares the CPU/GPU evaluation overrides and a private eval() helper, all
// defined outside this header.
// NOTE(review): embedded line numbers 336 and 341-345 are missing from this
// listing — confirm against the original header.
334class ArcTanh : public UnaryPrimitive {
335 public:
337
338 void eval_cpu(const std::vector<array>& inputs, array& out) override;
339 void eval_gpu(const std::vector<array>& inputs, array& out) override;
340
346
347 private:
348 void eval(const std::vector<array>& inputs, array& out);
349};
350
// NOTE(review): the `class ... {` line (embedded line 351) is missing from
// this listing. The constructor below names the class ArgPartition and
// initialises UnaryPrimitive as its base.
// The constructor stores `kth` and `axis`; is_equivalent is overridden and
// defined out of line.
// NOTE(review): embedded line numbers 359-362 are also missing — confirm
// against the original header.
352 public:
353 explicit ArgPartition(Stream stream, int kth, int axis)
354 : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
355
356 void eval_cpu(const std::vector<array>& inputs, array& out) override;
357 void eval_gpu(const std::vector<array>& inputs, array& out) override;
358
363 bool is_equivalent(const Primitive& other) const override;
364
365 private:
366 int kth_;
367 int axis_;
368
369 void eval(const std::vector<array>& inputs, array& out);
370};
371
// ArgReduce: single-output primitive parameterised by a ReduceType and an
// axis, both stored by the constructor.
// Overrides is_equivalent and output_shapes; both are defined out of line.
// NOTE(review): embedded line numbers 374-377 and 385-387 are missing from
// this listing. `ReduceType` is not declared in the visible text and is
// presumably defined on 374-377 — confirm against the original header.
372class ArgReduce : public UnaryPrimitive {
373 public:
378
379 explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)
380 : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {}
381
382 void eval_cpu(const std::vector<array>& inputs, array& out) override;
383 void eval_gpu(const std::vector<array>& inputs, array& out) override;
384
388 bool is_equivalent(const Primitive& other) const override;
389 std::vector<std::vector<int>> output_shapes(
390 const std::vector<array>& inputs) override;
391
392 private:
393 ReduceType reduce_type_;
394 int axis_;
395
396 void eval(const std::vector<array>& inputs, array& out);
397};
398
// ArgSort: single-output primitive parameterised by an axis, which the
// constructor stores after forwarding `stream` to UnaryPrimitive.
// is_equivalent is overridden and defined out of line.
// NOTE(review): embedded line numbers 407-409 are missing from this listing —
// confirm against the original header.
399class ArgSort : public UnaryPrimitive {
400 public:
401 explicit ArgSort(Stream stream, int axis)
402 : UnaryPrimitive(stream), axis_(axis) {}
403
404 void eval_cpu(const std::vector<array>& inputs, array& out) override;
405 void eval_gpu(const std::vector<array>& inputs, array& out) override;
406
410 bool is_equivalent(const Primitive& other) const override;
411
412 private:
413 int axis_;
414
415 void eval(const std::vector<array>& inputs, array& out);
416};
417
// AsType: single-output primitive parameterised by a Dtype, which the
// constructor stores after forwarding `stream` to UnaryPrimitive.
// is_equivalent is overridden and defined out of line.
// NOTE(review): embedded line numbers 426-429 are missing from this listing —
// confirm against the original header.
418class AsType : public UnaryPrimitive {
419 public:
420 explicit AsType(Stream stream, Dtype dtype)
421 : UnaryPrimitive(stream), dtype_(dtype) {}
422
423 void eval_cpu(const std::vector<array>& inputs, array& out) override;
424 void eval_gpu(const std::vector<array>& inputs, array& out) override;
425
430 bool is_equivalent(const Primitive& other) const override;
431
432 private:
433 Dtype dtype_;
434
435 void eval(const std::vector<array>& inputs, array& out);
436};
437
// AsStrided: single-output primitive holding a shape, a strides vector and an
// offset. The shape and strides arguments are taken by value and moved into
// the members.
// NOTE(review): embedded line numbers 441, 445 and 453-454 are missing from
// this listing. The constructor's first parameter and the start of its
// initialiser list sit on 441 and 445, so the constructor as shown is
// incomplete — confirm against the original header.
438class AsStrided : public UnaryPrimitive {
439 public:
440 explicit AsStrided(
442 std::vector<int> shape,
443 std::vector<size_t> strides,
444 size_t offset)
446 shape_(std::move(shape)),
447 strides_(std::move(strides)),
448 offset_(offset) {}
449
450 void eval_cpu(const std::vector<array>& inputs, array& out) override;
451 void eval_gpu(const std::vector<array>& inputs, array& out) override;
452
455 bool is_equivalent(const Primitive& other) const override;
456
457 private:
458 std::vector<int> shape_;
459 std::vector<size_t> strides_;
460 size_t offset_;
461
462 void eval(const std::vector<array>& inputs, array& out);
463};
464
// NOTE(review): the `class ... {` line (embedded line 465) and the
// constructor (469-470) are missing from this listing, so the class name is
// not visible here. Presumably this is the BitwiseBinary primitive — confirm
// against the original header.
// The class defines an `Op` enumeration with five operations and stores one
// value of it in op_. print and is_equivalent are overridden and defined out
// of line.
// NOTE(review): embedded line numbers 475-476 and 479 are also missing.
466 public:
467 enum Op { And, Or, Xor, LeftShift, RightShift };
468
471
472 void eval_cpu(const std::vector<array>& inputs, array& out) override;
473 void eval_gpu(const std::vector<array>& inputs, array& out) override;
474
477 bool is_equivalent(const Primitive& other) const override;
478 void print(std::ostream& os) override;
480
481 private:
482 Op op_;
483};
484
// NOTE(review): the `class ... {` line (embedded line 485) is missing from
// this listing. The constructor below names the class BlockMaskedMM and
// initialises UnaryPrimitive as its base.
// The constructor stores `block_size`. vjp and is_equivalent are overridden
// and defined out of line.
// NOTE(review): embedded line number 499 is also missing — confirm against
// the original header.
486 public:
487 explicit BlockMaskedMM(Stream stream, int block_size)
488 : UnaryPrimitive(stream), block_size_(block_size) {}
489
490 void eval_cpu(const std::vector<array>& inputs, array& out) override;
491 void eval_gpu(const std::vector<array>& inputs, array& out) override;
492
493 std::vector<array> vjp(
494 const std::vector<array>& primals,
495 const std::vector<array>& cotangents,
496 const std::vector<int>& argnums,
497 const std::vector<array>& outputs) override;
498
500 bool is_equivalent(const Primitive& other) const override;
501
502 private:
503 int block_size_;
504
505 void eval(const std::vector<array>& inputs, array& out);
506};
507
// GatherMM: single-output primitive (subclass of UnaryPrimitive).
// Overrides vjp in addition to the CPU/GPU evaluation functions; everything
// is defined outside this header.
// NOTE(review): embedded line numbers 510 and 521-522 are missing from this
// listing — confirm against the original header.
508class GatherMM : public UnaryPrimitive {
509 public:
511
512 void eval_cpu(const std::vector<array>& inputs, array& out) override;
513 void eval_gpu(const std::vector<array>& inputs, array& out) override;
514
515 std::vector<array> vjp(
516 const std::vector<array>& primals,
517 const std::vector<array>& cotangents,
518 const std::vector<int>& argnums,
519 const std::vector<array>& outputs) override;
520
523
524 private:
525 void eval(const std::vector<array>& inputs, array& out);
526};
527
// Broadcast: single-output primitive holding a shape, which the constructor
// copies from its const-reference argument.
// is_equivalent is overridden and defined out of line.
// NOTE(review): embedded line numbers 536-538 are missing from this listing —
// confirm against the original header.
528class Broadcast : public UnaryPrimitive {
529 public:
530 explicit Broadcast(Stream stream, const std::vector<int>& shape)
531 : UnaryPrimitive(stream), shape_(shape) {}
532
533 void eval_cpu(const std::vector<array>& inputs, array& out) override;
534 void eval_gpu(const std::vector<array>& inputs, array& out) override;
535
539 bool is_equivalent(const Primitive& other) const override;
540
541 private:
542 std::vector<int> shape_;
543
544 void eval(const std::vector<array>& inputs, array& out);
545};
546
// Ceil: single-output primitive (subclass of UnaryPrimitive).
// Declares the CPU/GPU evaluation overrides and a private eval() helper, all
// defined outside this header.
// NOTE(review): embedded line numbers 549 and 554-558 are missing from this
// listing — confirm against the original header.
547class Ceil : public UnaryPrimitive {
548 public:
550
551 void eval_cpu(const std::vector<array>& inputs, array& out) override;
552 void eval_gpu(const std::vector<array>& inputs, array& out) override;
553
559
560 private:
561 void eval(const std::vector<array>& inputs, array& out);
562};
563
// Compiled: multi-output primitive (derives directly from Primitive) that
// holds a set of input, output and tape arrays plus a set of constant ids,
// all immutable after construction (const members).
// The constructor, print, is_equivalent and output_shapes are declared here
// and defined out of line.
// NOTE(review): embedded line numbers 576 and 587-588 are missing from this
// listing; the constructor's first parameter sits on 576, so the declaration
// as shown is incomplete — confirm against the original header.
564class Compiled : public Primitive {
565 public:
566 /*
567 * The inputs, outputs and tape are either tracers or constants.
568 * - The tape should not contain the inputs, but it should contain the
569 * outputs.
570 * - The tape should also have only one array per primitive for multi-output
571 * primitives.
572 * - The constant_ids contains ids of arrays in the input list that are safe
573 * to treat as scalar constants.
574 */
575 explicit Compiled(
577 std::vector<array> inputs,
578 std::vector<array> outputs,
579 std::vector<array> tape,
580 std::unordered_set<uintptr_t> constant_ids);
581
582 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
583 override;
584 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
585 override;
586
589 std::vector<std::vector<int>> output_shapes(
590 const std::vector<array>& inputs) override;
591 void print(std::ostream& os) override;
592 bool is_equivalent(const Primitive& other) const override;
593
// Returns a copy of the stored kernel library name.
594 std::string lib_name() const {
595 return kernel_lib_;
596 }
597
598 private:
599 const std::vector<array> inputs_;
600 const std::vector<array> outputs_;
601 const std::vector<array> tape_;
602 const std::unordered_set<uintptr_t> constant_ids_;
603
604 std::string kernel_lib_;
605};
606
// NOTE(review): the `class ... {` line (embedded line 607) is missing from
// this listing. The constructor below names the class Concatenate and
// initialises UnaryPrimitive as its base.
// The constructor stores `axis`; is_equivalent is overridden and defined out
// of line.
// NOTE(review): embedded line numbers 615-617 are also missing — confirm
// against the original header.
608 public:
609 explicit Concatenate(Stream stream, int axis)
610 : UnaryPrimitive(stream), axis_(axis) {}
611
612 void eval_cpu(const std::vector<array>& inputs, array& out) override;
613 void eval_gpu(const std::vector<array>& inputs, array& out) override;
614
618 bool is_equivalent(const Primitive& other) const override;
619
620 private:
621 int axis_;
622
623 void eval(const std::vector<array>& inputs, array& out);
624};
625
// Conjugate: single-output primitive with no parameters beyond the stream,
// which the constructor forwards to UnaryPrimitive.
// NOTE(review): embedded line numbers 633-636 are missing from this listing —
// confirm against the original header.
626class Conjugate : public UnaryPrimitive {
627 public:
628 explicit Conjugate(Stream stream) : UnaryPrimitive(stream) {}
629
630 void eval_cpu(const std::vector<array>& inputs, array& out) override;
631 void eval_gpu(const std::vector<array>& inputs, array& out) override;
632
637
638 private:
639 void eval(const std::vector<array>& inputs, array& out);
640};
641
// NOTE(review): the `class ... {` line (embedded line 642) is missing from
// this listing. The constructor below names the class Convolution and
// initialises UnaryPrimitive as its base.
// The constructor copies the stride, padding and dilation vectors and stores
// `groups` (default 1) and `flip` (default false). vjp and is_equivalent are
// overridden and defined out of line.
// NOTE(review): embedded line number 669 is also missing — confirm against
// the original header.
643 public:
644 explicit Convolution(
645 Stream stream,
646 const std::vector<int>& kernel_strides,
647 const std::vector<int>& padding,
648 const std::vector<int>& kernel_dilation,
649 const std::vector<int>& input_dilation,
650 const int groups = 1,
651 const bool flip = false)
652 : UnaryPrimitive(stream),
653 padding_(padding),
654 kernel_strides_(kernel_strides),
655 kernel_dilation_(kernel_dilation),
656 input_dilation_(input_dilation),
657 groups_(groups),
658 flip_(flip) {}
659
660 void eval_cpu(const std::vector<array>& inputs, array& out) override;
661 void eval_gpu(const std::vector<array>& inputs, array& out) override;
662
663 std::vector<array> vjp(
664 const std::vector<array>& primals,
665 const std::vector<array>& cotangents,
666 const std::vector<int>& argnums,
667 const std::vector<array>& outputs) override;
668
670 bool is_equivalent(const Primitive& other) const override;
671
672 private:
673 std::vector<int> padding_;
674 std::vector<int> kernel_strides_;
675 std::vector<int> kernel_dilation_;
676 std::vector<int> input_dilation_;
677 int groups_;
678 bool flip_;
679
680 void eval(const std::vector<array>& inputs, array& out);
681};
682
// Copy: single-output primitive with no parameters beyond the stream, which
// the constructor forwards to UnaryPrimitive.
// NOTE(review): embedded line numbers 690-694 are missing from this listing —
// confirm against the original header.
683class Copy : public UnaryPrimitive {
684 public:
685 explicit Copy(Stream stream) : UnaryPrimitive(stream) {}
686
687 void eval_cpu(const std::vector<array>& inputs, array& out) override;
688 void eval_gpu(const std::vector<array>& inputs, array& out) override;
689
695
696 private:
697 void eval(const std::vector<array>& inputs, array& out);
698};
699
// Cos: single-output primitive with no parameters beyond the stream, which
// the constructor forwards to UnaryPrimitive.
// NOTE(review): embedded line numbers 707-711 are missing from this listing —
// confirm against the original header.
700class Cos : public UnaryPrimitive {
701 public:
702 explicit Cos(Stream stream) : UnaryPrimitive(stream) {}
703
704 void eval_cpu(const std::vector<array>& inputs, array& out) override;
705 void eval_gpu(const std::vector<array>& inputs, array& out) override;
706
712
713 private:
714 void eval(const std::vector<array>& inputs, array& out);
715};
716
// Cosh: single-output primitive with no parameters beyond the stream, which
// the constructor forwards to UnaryPrimitive.
// NOTE(review): embedded line numbers 724-728 are missing from this listing —
// confirm against the original header.
717class Cosh : public UnaryPrimitive {
718 public:
719 explicit Cosh(Stream stream) : UnaryPrimitive(stream) {}
720
721 void eval_cpu(const std::vector<array>& inputs, array& out) override;
722 void eval_gpu(const std::vector<array>& inputs, array& out) override;
723
729
730 private:
731 void eval(const std::vector<array>& inputs, array& out);
732};
733
// NOTE(review): the `class ... {` line (embedded line 734) and the line
// opening the constructor (736) are missing from this listing, so the class
// name is not visible here. Presumably this is the CustomTransforms
// primitive — confirm against the original header.
// The visible constructor tail initialises Primitive directly (multi-output
// interface), stores `num_outputs`, and moves three std::function callbacks
// into vjp_fun_, jvp_fun_ and vmap_fun_.
// NOTE(review): embedded line numbers 761-763 are also missing.
735 public:
737 Stream stream,
738 int num_outputs,
739 std::function<std::vector<array>(
740 const std::vector<array>&,
741 const std::vector<array>&,
742 const std::vector<array>&)> vjp,
743 std::function<std::vector<array>(
744 const std::vector<array>&,
745 const std::vector<array>&,
746 const std::vector<int>&)> jvp,
747 std::function<std::pair<std::vector<array>, std::vector<int>>(
748 const std::vector<array>&,
749 const std::vector<int>&)> vmap)
750 : Primitive(stream),
751 num_outputs_(num_outputs),
752 vjp_fun_(std::move(vjp)),
753 jvp_fun_(std::move(jvp)),
754 vmap_fun_(std::move(vmap)) {}
755
756 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
757 override;
758 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
759 override;
760
764
765 private:
766 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
767
768 int num_outputs_;
769
// User-supplied callbacks captured at construction. Note that the vjp
// callback's third argument is a vector of arrays, while the jvp callback's
// third argument is a vector of ints.
770 std::function<std::vector<array>(
771 const std::vector<array>&,
772 const std::vector<array>&,
773 const std::vector<array>&)>
774 vjp_fun_;
775 std::function<std::vector<array>(
776 const std::vector<array>&,
777 const std::vector<array>&,
778 const std::vector<int>&)>
779 jvp_fun_;
780 std::function<std::pair<std::vector<array>, std::vector<int>>(
781 const std::vector<array>&,
782 const std::vector<int>&)>
783 vmap_fun_;
784};
785
// Depends: multi-output primitive (derives directly from Primitive) with no
// parameters beyond the stream.
// Overrides vjp, which is defined out of line.
// NOTE(review): embedded line number 801 is missing from this listing —
// confirm against the original header.
786class Depends : public Primitive {
787 public:
788 explicit Depends(Stream stream) : Primitive(stream) {}
789
790 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
791 override;
792 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
793 override;
794
795 std::vector<array> vjp(
796 const std::vector<array>& primals,
797 const std::vector<array>& cotan,
798 const std::vector<int>& argnums,
799 const std::vector<array>& outputs) override;
800
802
803 private:
804 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
805};
806
// Divide: single-output primitive with no parameters beyond the stream, which
// the constructor forwards to UnaryPrimitive.
// NOTE(review): embedded line numbers 814-818 are missing from this listing —
// confirm against the original header.
807class Divide : public UnaryPrimitive {
808 public:
809 explicit Divide(Stream stream) : UnaryPrimitive(stream) {}
810
811 void eval_cpu(const std::vector<array>& inputs, array& out) override;
812 void eval_gpu(const std::vector<array>& inputs, array& out) override;
813
819
820 private:
821 void eval(const std::vector<array>& inputs, array& out);
822};
823
// DivMod: multi-output primitive (derives directly from Primitive) with no
// parameters beyond the stream.
// NOTE(review): embedded line numbers 833-836 are missing from this listing —
// confirm against the original header.
824class DivMod : public Primitive {
825 public:
826 explicit DivMod(Stream stream) : Primitive(stream) {}
827
828 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
829 override;
830 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
831 override;
832
// Reports two output shapes, both copied from the first input's shape.
// inputs[0] is indexed without a size check.
837 std::vector<std::vector<int>> output_shapes(
838 const std::vector<array>& inputs) override {
839 return std::vector{inputs[0].shape(), inputs[0].shape()};
840 }
841
842 private:
843 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
844};
845
// Select: single-output primitive with no parameters beyond the stream, which
// the constructor forwards to UnaryPrimitive.
// NOTE(review): embedded line numbers 853-857 are missing from this listing —
// confirm against the original header.
846class Select : public UnaryPrimitive {
847 public:
848 explicit Select(Stream stream) : UnaryPrimitive(stream) {}
849
850 void eval_cpu(const std::vector<array>& inputs, array& out) override;
851 void eval_gpu(const std::vector<array>& inputs, array& out) override;
852
858
859 private:
860 void eval(const std::vector<array>& inputs, array& out);
861};
862
// Remainder: single-output primitive with no parameters beyond the stream,
// which the constructor forwards to UnaryPrimitive.
// NOTE(review): embedded line numbers 870-874 are missing from this listing —
// confirm against the original header.
863class Remainder : public UnaryPrimitive {
864 public:
865 explicit Remainder(Stream stream) : UnaryPrimitive(stream) {}
866
867 void eval_cpu(const std::vector<array>& inputs, array& out) override;
868 void eval_gpu(const std::vector<array>& inputs, array& out) override;
869
875
876 private:
877 void eval(const std::vector<array>& inputs, array& out);
878};
879
// Equal: single-output primitive with an `equal_nan` flag (default false)
// that the constructor stores.
// NOTE(review): how the flag affects evaluation is decided in the out-of-view
// definitions; only its effect on print() is visible here.
// NOTE(review): embedded line numbers 888-891 are missing from this listing —
// confirm against the original header.
880class Equal : public UnaryPrimitive {
881 public:
882 explicit Equal(Stream stream, bool equal_nan = false)
883 : UnaryPrimitive(stream), equal_nan_(equal_nan) {}
884
885 void eval_cpu(const std::vector<array>& inputs, array& out) override;
886 void eval_gpu(const std::vector<array>& inputs, array& out) override;
887
892
// Prints "NaNEqual" when the flag is set and "Equal" otherwise.
893 void print(std::ostream& os) override {
894 if (equal_nan_) {
895 os << "NaNEqual";
896 } else {
897 os << "Equal";
898 }
899 }
900
901 private:
902 void eval(const std::vector<array>& inputs, array& out);
903 bool equal_nan_;
904};
905
// Erf: single-output primitive with no parameters beyond the stream, which
// the constructor forwards to UnaryPrimitive.
// NOTE(review): embedded line numbers 913-917 are missing from this listing —
// confirm against the original header.
906class Erf : public UnaryPrimitive {
907 public:
908 explicit Erf(Stream stream) : UnaryPrimitive(stream) {}
909
910 void eval_cpu(const std::vector<array>& inputs, array& out) override;
911 void eval_gpu(const std::vector<array>& inputs, array& out) override;
912
918
919 private:
920 void eval(const std::vector<array>& inputs, array& out);
921};
922
// ErfInv: single-output primitive with no parameters beyond the stream, which
// the constructor forwards to UnaryPrimitive.
// NOTE(review): embedded line numbers 930-934 are missing from this listing —
// confirm against the original header.
923class ErfInv : public UnaryPrimitive {
924 public:
925 explicit ErfInv(Stream stream) : UnaryPrimitive(stream) {}
926
927 void eval_cpu(const std::vector<array>& inputs, array& out) override;
928 void eval_gpu(const std::vector<array>& inputs, array& out) override;
929
935
936 private:
937 void eval(const std::vector<array>& inputs, array& out);
938};
939
// Exp: single-output primitive with no parameters beyond the stream, which
// the constructor forwards to UnaryPrimitive.
// NOTE(review): embedded line numbers 947-951 are missing from this listing —
// confirm against the original header.
940class Exp : public UnaryPrimitive {
941 public:
942 explicit Exp(Stream stream) : UnaryPrimitive(stream) {}
943
944 void eval_cpu(const std::vector<array>& inputs, array& out) override;
945 void eval_gpu(const std::vector<array>& inputs, array& out) override;
946
952
953 private:
954 void eval(const std::vector<array>& inputs, array& out);
955};
956
// Expm1: single-output primitive with no parameters beyond the stream, which
// the constructor forwards to UnaryPrimitive.
// NOTE(review): embedded line numbers 964-967 are missing from this listing —
// confirm against the original header.
957class Expm1 : public UnaryPrimitive {
958 public:
959 explicit Expm1(Stream stream) : UnaryPrimitive(stream) {}
960
961 void eval_cpu(const std::vector<array>& inputs, array& out) override;
962 void eval_gpu(const std::vector<array>& inputs, array& out) override;
963
968
969 private:
970 void eval(const std::vector<array>& inputs, array& out);
971};
972
// FFT: single-output primitive holding a list of axes and two boolean flags
// (`inverse`, `real`), all stored by the constructor.
// is_equivalent is overridden and defined out of line.
// NOTE(review): embedded line numbers 985-987 are missing from this listing —
// confirm against the original header.
973class FFT : public UnaryPrimitive {
974 public:
975 explicit FFT(
976 Stream stream,
977 const std::vector<size_t>& axes,
978 bool inverse,
979 bool real)
980 : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {}
981
982 void eval_cpu(const std::vector<array>& inputs, array& out) override;
983 void eval_gpu(const std::vector<array>& inputs, array& out) override;
984
988
989 bool is_equivalent(const Primitive& other) const override;
990
991 private:
992 std::vector<size_t> axes_;
993 bool inverse_;
994 bool real_;
995
996 void eval(const std::vector<array>& inputs, array& out);
997};
998
// Floor: single-output primitive with no parameters beyond the stream, which
// the constructor forwards to UnaryPrimitive.
// NOTE(review): embedded line numbers 1006-1010 are missing from this
// listing — confirm against the original header.
999class Floor : public UnaryPrimitive {
1000 public:
1001 explicit Floor(Stream stream) : UnaryPrimitive(stream) {}
1002
1003 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1004 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1005
1011
1012 private:
1013 void eval(const std::vector<array>& inputs, array& out);
1014};
1015
// Full: single-output primitive with no parameters beyond the stream, which
// the constructor forwards to UnaryPrimitive.
// NOTE(review): embedded line numbers 1023-1026 are missing from this
// listing — confirm against the original header.
1016class Full : public UnaryPrimitive {
1017 public:
1018 explicit Full(Stream stream) : UnaryPrimitive(stream) {}
1019
1020 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1021 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1022
1027
1028 private:
1029 void eval(const std::vector<array>& inputs, array& out);
1030};
1031
// Gather: single-output primitive holding a list of axes and a list of slice
// sizes, both copied by the constructor.
// is_equivalent is overridden and defined out of line.
// NOTE(review): embedded line numbers 1043-1045 are missing from this
// listing — confirm against the original header.
1032class Gather : public UnaryPrimitive {
1033 public:
1034 explicit Gather(
1035 Stream stream,
1036 const std::vector<int>& axes,
1037 const std::vector<int>& slice_sizes)
1038 : UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes) {}
1039
1040 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1041 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1042
1046 bool is_equivalent(const Primitive& other) const override;
1047
1048 private:
1049 void eval(const std::vector<array>& inputs, array& out);
1050 std::vector<int> axes_;
1051 std::vector<int> slice_sizes_;
1052};
1053
// Greater: single-output primitive with no parameters beyond the stream,
// which the constructor forwards to UnaryPrimitive.
// NOTE(review): embedded line numbers 1061-1065 are missing from this
// listing — confirm against the original header.
1054class Greater : public UnaryPrimitive {
1055 public:
1056 explicit Greater(Stream stream) : UnaryPrimitive(stream) {}
1057
1058 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1059 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1060
1066
1067 private:
1068 void eval(const std::vector<array>& inputs, array& out);
1069};
1070
// NOTE(review): the `class ... {` line (embedded line 1071) is missing from
// this listing. The constructor below names the class GreaterEqual and
// initialises UnaryPrimitive as its base.
// NOTE(review): embedded line numbers 1078-1082 are also missing — confirm
// against the original header.
1072 public:
1073 explicit GreaterEqual(Stream stream) : UnaryPrimitive(stream) {}
1074
1075 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1076 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1077
1083
1084 private:
1085 void eval(const std::vector<array>& inputs, array& out);
1086};
1087
// Hadamard: single-output primitive parameterised by a float scale, which the
// constructor stores after forwarding `stream` to UnaryPrimitive.
// is_equivalent is overridden and defined out of line.
// NOTE(review): embedded line numbers 1096-1099 are missing from this
// listing — confirm against the original header.
1088class Hadamard : public UnaryPrimitive {
1089 public:
1090 explicit Hadamard(Stream stream, float scale)
1091 : UnaryPrimitive(stream), scale_(scale) {}
1092
1093 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1094 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1095
1100
1101 bool is_equivalent(const Primitive& other) const override;
1102
1103 private:
1104 float scale_;
1105
1106 void eval(const std::vector<array>& inputs, array& out);
1107};
1108
// Less: single-output primitive with no parameters beyond the stream, which
// the constructor forwards to UnaryPrimitive.
// NOTE(review): embedded line numbers 1116-1120 are missing from this
// listing — confirm against the original header.
1109class Less : public UnaryPrimitive {
1110 public:
1111 explicit Less(Stream stream) : UnaryPrimitive(stream) {}
1112
1113 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1114 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1115
1121
1122 private:
1123 void eval(const std::vector<array>& inputs, array& out);
1124};
1125
// NOTE(review): the `class ... {` line (embedded line 1126) is missing from
// this listing. The constructor below names the class LessEqual and
// initialises UnaryPrimitive as its base.
// NOTE(review): embedded line numbers 1133-1137 are also missing — confirm
// against the original header.
1127 public:
1128 explicit LessEqual(Stream stream) : UnaryPrimitive(stream) {}
1129
1130 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1131 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1132
1138
1139 private:
1140 void eval(const std::vector<array>& inputs, array& out);
1141};
1142
// Load: single-output primitive holding a shared io::Reader, a byte offset
// and an endianness-swap flag (default false), all stored by the constructor.
// NOTE(review): embedded line number 1162 is missing from this listing —
// confirm against the original header.
1143class Load : public UnaryPrimitive {
1144 public:
1145 explicit Load(
1146 Stream stream,
1147 std::shared_ptr<io::Reader> reader,
1148 size_t offset,
1149 bool swap_endianness = false)
1150 : UnaryPrimitive(stream),
1151 reader_(std::move(reader)),
1152 offset_(offset),
1153 swap_endianness_(swap_endianness) {
// For GPU streams, call io_stream() and discard the result; the visible
// effect is to make sure the function-local static stream has been created.
1154 if (stream.device == Device::gpu) {
1155 io_stream();
1156 }
1157 }
1158
1159 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1160 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1161
1163
1164 private:
// Returns a reference to a function-local static Stream, created on first
// call with new_stream(Device::cpu) and shared by every later call.
1165 Stream& io_stream() {
1166 static Stream io_stream = new_stream(Device::cpu);
1167 return io_stream;
1168 };
1169 void eval(const std::vector<array>& inputs, array& out);
1170 std::shared_ptr<io::Reader> reader_;
1171 size_t offset_;
1172 bool swap_endianness_;
1173};
1174
// Log: single-output primitive parameterised by a Base (two, ten or e), which
// the constructor stores after forwarding `stream` to UnaryPrimitive.
// NOTE(review): embedded line numbers 1185-1188 are missing from this
// listing — confirm against the original header.
1175class Log : public UnaryPrimitive {
1176 public:
1177 enum Base { two, ten, e };
1178
1179 explicit Log(Stream stream, Base base)
1180 : UnaryPrimitive(stream), base_(base) {}
1181
1182 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1183 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1184
1189
// Prints "Log", "Log2" or "Log10" for base e, two or ten respectively. There
// is no default case, so any other stored value prints nothing.
1190 void print(std::ostream& os) override {
1191 switch (base_) {
1192 case e:
1193 os << "Log";
1194 break;
1195 case two:
1196 os << "Log2";
1197 break;
1198 case ten:
1199 os << "Log10";
1200 break;
1201 }
1202 }
1203
1204 private:
1205 Base base_;
1206 void eval(const std::vector<array>& inputs, array& out);
1207};
1208
// Log1p: single-output primitive with no parameters beyond the stream, which
// the constructor forwards to UnaryPrimitive.
// NOTE(review): embedded line numbers 1216-1219 are missing from this
// listing — confirm against the original header.
1209class Log1p : public UnaryPrimitive {
1210 public:
1211 explicit Log1p(Stream stream) : UnaryPrimitive(stream) {}
1212
1213 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1214 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1215
1220
1221 private:
1222 void eval(const std::vector<array>& inputs, array& out);
1223};
1224
// NOTE(review): the `class ... {` line (embedded line 1225) is missing from
// this listing. The constructor below names the class LogicalNot and
// initialises UnaryPrimitive as its base.
// NOTE(review): embedded line numbers 1232-1236 are also missing — confirm
// against the original header.
1226 public:
1227 explicit LogicalNot(Stream stream) : UnaryPrimitive(stream) {}
1228
1229 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1230 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1231
1237
1238 private:
1239 void eval(const std::vector<array>& inputs, array& out);
1240};
1241
// NOTE(review): the `class ... {` line (embedded line 1242) is missing from
// this listing. The constructor below names the class LogicalAnd and
// initialises UnaryPrimitive as its base.
// NOTE(review): embedded line numbers 1249-1253 are also missing — confirm
// against the original header.
1243 public:
1244 explicit LogicalAnd(Stream stream) : UnaryPrimitive(stream) {}
1245
1246 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1247 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1248
1254
1255 private:
1256 void eval(const std::vector<array>& inputs, array& out);
1257};
1258
// NOTE(review): the `class ... {` line (embedded line 1259) is missing from
// this listing. The constructor below names the class LogicalOr and
// initialises UnaryPrimitive as its base.
// NOTE(review): embedded line numbers 1266-1270 are also missing — confirm
// against the original header.
1260 public:
1261 explicit LogicalOr(Stream stream) : UnaryPrimitive(stream) {}
1262
1263 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1264 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1265
1271
1272 private:
1273 void eval(const std::vector<array>& inputs, array& out);
1274};
1275
// NOTE(review): the `class ... {` line (embedded line 1276) is missing from
// this listing. The constructor below names the class LogAddExp and
// initialises UnaryPrimitive as its base.
// NOTE(review): embedded line numbers 1283-1287 are also missing — confirm
// against the original header.
1277 public:
1278 explicit LogAddExp(Stream stream) : UnaryPrimitive(stream) {}
1279
1280 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1281 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1282
1288
1289 private:
1290 void eval(const std::vector<array>& inputs, array& out);
1291};
1292
// Matmul: single-output primitive with no parameters beyond the stream.
// Overrides vjp in addition to the CPU/GPU evaluation functions. Unlike most
// neighbouring classes, no private eval() helper is visible here.
// NOTE(review): embedded line numbers 1306-1308 are missing from this
// listing — confirm against the original header.
1293class Matmul : public UnaryPrimitive {
1294 public:
1295 explicit Matmul(Stream stream) : UnaryPrimitive(stream) {}
1296
1297 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1298 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1299
1300 std::vector<array> vjp(
1301 const std::vector<array>& primals,
1302 const std::vector<array>& cotangents,
1303 const std::vector<int>& argnums,
1304 const std::vector<array>& outputs) override;
1305
1309};
1310
// Maximum: single-output primitive with no parameters beyond the stream,
// which the constructor forwards to UnaryPrimitive.
// NOTE(review): embedded line numbers 1318-1322 are missing from this
// listing — confirm against the original header.
1311class Maximum : public UnaryPrimitive {
1312 public:
1313 explicit Maximum(Stream stream) : UnaryPrimitive(stream) {}
1314
1315 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1316 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1317
1323
1324 private:
1325 void eval(const std::vector<array>& inputs, array& out);
1326};
1327
// Minimum: single-output primitive with no parameters beyond the stream,
// which the constructor forwards to UnaryPrimitive.
// NOTE(review): embedded line numbers 1335-1339 are missing from this
// listing — confirm against the original header.
1328class Minimum : public UnaryPrimitive {
1329 public:
1330 explicit Minimum(Stream stream) : UnaryPrimitive(stream) {}
1331
1332 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1333 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1334
1340
1341 private:
1342 void eval(const std::vector<array>& inputs, array& out);
1343};
1344
// Multiply: single-output primitive with no parameters beyond the stream,
// which the constructor forwards to UnaryPrimitive.
// NOTE(review): embedded line numbers 1352-1356 are missing from this
// listing — confirm against the original header.
1345class Multiply : public UnaryPrimitive {
1346 public:
1347 explicit Multiply(Stream stream) : UnaryPrimitive(stream) {}
1348
1349 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1350 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1351
1357
1358 private:
1359 void eval(const std::vector<array>& inputs, array& out);
1360};
1361
// Negative: single-output primitive with no parameters beyond the stream,
// which the constructor forwards to UnaryPrimitive.
// NOTE(review): embedded line numbers 1369-1373 are missing from this
// listing — confirm against the original header.
1362class Negative : public UnaryPrimitive {
1363 public:
1364 explicit Negative(Stream stream) : UnaryPrimitive(stream) {}
1365
1366 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1367 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1368
1374
1375 private:
1376 void eval(const std::vector<array>& inputs, array& out);
1377};
1378
// NotEqual: single-output primitive with no parameters beyond the stream,
// which the constructor forwards to UnaryPrimitive.
// NOTE(review): embedded line numbers 1386-1390 are missing from this
// listing — confirm against the original header.
1379class NotEqual : public UnaryPrimitive {
1380 public:
1381 explicit NotEqual(Stream stream) : UnaryPrimitive(stream) {}
1382
1383 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1384 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1385
1391
1392 private:
1393 void eval(const std::vector<array>& inputs, array& out);
1394};
1395
// NOTE(review): the `class ... {` line (embedded line 1396) and the line
// opening the constructor (1398) are missing from this listing, so the class
// name is not visible here. Presumably this is the NumberOfElements
// primitive — confirm against the original header.
// The visible constructor tail initialises UnaryPrimitive, moves `axes` into
// the member and stores `inverted` and `dtype`. is_equivalent is overridden
// and defined out of line.
// NOTE(review): embedded line numbers 1411-1412 are also missing.
1397 public:
1399 Stream stream,
1400 std::vector<int> axes,
1401 bool inverted,
1402 Dtype dtype)
1403 : UnaryPrimitive(stream),
1404 axes_(std::move(axes)),
1405 inverted_(inverted),
1406 dtype_(dtype) {}
1407
1408 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1409 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1410
1413 bool is_equivalent(const Primitive& other) const override;
// Reports one output whose shape is empty (no dimensions), regardless of the
// inputs.
1414 std::vector<std::vector<int>> output_shapes(
1415 const std::vector<array>& inputs) override {
1416 return {{}};
1417 }
1418
1419 private:
1420 std::vector<int> axes_;
1421 bool inverted_;
1422 Dtype dtype_;
1423
1424 void eval(const std::vector<array>& inputs, array& out);
1425};
1426
// Pad primitive, derived from UnaryPrimitive. The constructor copies three
// integer vectors into members: the axes, and the low and high pad sizes.
// Declares CPU/GPU evaluation overrides, an is_equivalent() override, and a
// private eval() helper.
// NOTE(review): presumably low_pad_size_[i] / high_pad_size_[i] apply to
// axes_[i]; the pairing is not visible in this header -- confirm.
// NOTE(review): listing lines 1442-1444 are absent from this extract.
1427class Pad : public UnaryPrimitive {
1428 public:
1429 explicit Pad(
1430 Stream stream,
1431 const std::vector<int>& axes,
1432 const std::vector<int>& low_pad_size,
1433 const std::vector<int>& high_pad_size)
1434 : UnaryPrimitive(stream),
1435 axes_(axes),
1436 low_pad_size_(low_pad_size),
1437 high_pad_size_(high_pad_size) {}
1438
1439 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1440 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1441
1445 bool is_equivalent(const Primitive& other) const override;
1446
1447 private:
1448 std::vector<int> axes_;
1449 std::vector<int> low_pad_size_;
1450 std::vector<int> high_pad_size_;
1451
1452 void eval(const std::vector<array>& inputs, array& out);
1453};
1454
// Partition primitive, derived from UnaryPrimitive. The constructor stores
// the `kth` and `axis` arguments. Declares CPU/GPU evaluation overrides, an
// is_equivalent() override, and a private eval() helper.
// NOTE(review): the class head on listing line 1455 is reconstructed from
// the constructor name and the UnaryPrimitive(stream) base initializer;
// verify against the upstream header.
// NOTE(review): listing lines 1463-1466 are absent from this extract.
1455class Partition : public UnaryPrimitive {
1456 public:
1457 explicit Partition(Stream stream, int kth, int axis)
1458 : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
1459
1460 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1461 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1462
1467 bool is_equivalent(const Primitive& other) const override;
1468
1469 private:
1470 int kth_;
1471 int axis_;
1472
1473 void eval(const std::vector<array>& inputs, array& out);
1474};
1475
// Power primitive, derived from UnaryPrimitive. The constructor forwards
// the stream to the base class. Declares CPU/GPU evaluation overrides
// (inputs in, one `out` array written) and a private eval() helper.
// NOTE(review): presumably element-wise exponentiation of the inputs --
// confirm against the implementation files.
// NOTE(review): listing lines 1483-1487 are absent from this extract.
1476class Power : public UnaryPrimitive {
1477 public:
1478 explicit Power(Stream stream) : UnaryPrimitive(stream) {}
1479
1480 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1481 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1482
1488
1489 private:
1490 void eval(const std::vector<array>& inputs, array& out);
1491};
1492
// NOTE(review): the class head (listing line 1493) and the line carrying the
// constructor name (listing line 1495) are missing from this extract, so the
// class name is not visible here. The base initializer shows it derives from
// UnaryPrimitive. Its constructor parameters match those of GatherQMM;
// presumably this is the QuantizedMatmul primitive -- confirm against the
// upstream header before relying on this block.
//
// The constructor stores `group_size`, `bits` and `transpose`.
// NOTE(review): listing lines 1508-1510 are also absent from this extract.
1494 public:
1496 Stream stream,
1497 int group_size,
1498 int bits,
1499 bool transpose)
1500 : UnaryPrimitive(stream),
1501 group_size_(group_size),
1502 bits_(bits),
1503 transpose_(transpose) {}
1504
1505 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1506 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1507
1511 bool is_equivalent(const Primitive& other) const override;
1512
1513 private:
1514 int group_size_;
1515 int bits_;
1516 bool transpose_;
1517
1518 void eval(const std::vector<array>& inputs, array& out);
1519};
1520
// GatherQMM primitive, derived from UnaryPrimitive. The constructor stores
// `group_size`, `bits` and `transpose`. Declares CPU/GPU evaluation
// overrides, an is_equivalent() override, and a private eval() helper.
// NOTE(review): the class head on listing line 1521 is reconstructed from
// the constructor name and the UnaryPrimitive(stream) base initializer;
// verify against the upstream header.
// NOTE(review): presumably a gather variant of quantized matrix
// multiplication -- semantics are not visible in this header; confirm.
// NOTE(review): listing lines 1532-1534 are absent from this extract.
1521class GatherQMM : public UnaryPrimitive {
1522 public:
1523 explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
1524 : UnaryPrimitive(stream),
1525 group_size_(group_size),
1526 bits_(bits),
1527 transpose_(transpose) {}
1528
1529 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1530 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1531
1535 bool is_equivalent(const Primitive& other) const override;
1536
1537 private:
1538 int group_size_;
1539 int bits_;
1540 bool transpose_;
1541
1542 void eval(const std::vector<array>& inputs, array& out);
1543};
1544
// RandomBits primitive, derived from UnaryPrimitive. The constructor copies
// the `shape` vector and stores `width`. Declares CPU/GPU evaluation
// overrides, an is_equivalent() override, and a private eval() helper.
// NOTE(review): the class head on listing line 1545 is reconstructed from
// the constructor name and the UnaryPrimitive(stream) base initializer;
// verify against the upstream header.
// NOTE(review): the unit of `width` (bits vs. bytes) is not visible in this
// header -- confirm.
// NOTE(review): listing lines 1553-1554 are absent from this extract.
1545class RandomBits : public UnaryPrimitive {
1546 public:
1547 explicit RandomBits(Stream stream, const std::vector<int>& shape, int width)
1548 : UnaryPrimitive(stream), shape_(shape), width_(width) {}
1549
1550 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1551 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1552
1555 bool is_equivalent(const Primitive& other) const override;
1556
1557 private:
1558 std::vector<int> shape_;
1559 int width_;
1560
1561 void eval(const std::vector<array>& inputs, array& out);
1562};
1563
// Reshape primitive, derived from UnaryPrimitive. The constructor copies the
// `shape` vector into shape_. Declares CPU/GPU evaluation overrides, an
// is_equivalent() override, a private eval() helper, and two private
// helpers: prepare_reshape() and shared_buffer_reshape().
// NOTE(review): listing lines 1572-1574 are absent from this extract.
1564class Reshape : public UnaryPrimitive {
1565 public:
1566 explicit Reshape(Stream stream, const std::vector<int>& shape)
1567 : UnaryPrimitive(stream), shape_(shape) {}
1568
1569 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1570 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1571
1575 bool is_equivalent(const Primitive& other) const override;
1576
1577 private:
1578 std::vector<int> shape_;
1579
1580 void eval(const std::vector<array>& inputs, array& out);
1581
// Returns a (bool, strides) pair computed from `in` and `out`.
// NOTE(review): presumably the bool says whether a copy is needed and the
// vector holds the output strides -- confirm against the implementation.
1582 std::pair<bool, std::vector<size_t>> prepare_reshape(
1583 const array& in,
1584 const array& out);
// Takes the input, a set of output strides, and the output array to fill.
// NOTE(review): the name suggests `out` aliases the buffer of `in` rather
// than copying -- confirm against the implementation.
1585 void shared_buffer_reshape(
1586 const array& in,
1587 const std::vector<size_t>& out_strides,
1588 array& out);
1589};
1590
// Reduce primitive, derived from UnaryPrimitive. One class covers six
// reduction kinds, selected by the ReduceType passed to the constructor,
// which also copies the `axes` vector. Declares CPU/GPU evaluation
// overrides, vjp(), output_shapes(), is_equivalent(), and a private eval()
// helper; print() is defined inline.
// NOTE(review): listing line 1604 is absent from this extract.
1591class Reduce : public UnaryPrimitive {
1592 public:
// Reduction kinds supported by this primitive.
1593 enum ReduceType { And, Or, Sum, Prod, Min, Max };
1594
1595 explicit Reduce(
1596 Stream stream,
1597 ReduceType reduce_type,
1598 const std::vector<int>& axes)
1599 : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1600
1601 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1602 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1603
1605
1606 std::vector<array> vjp(
1607 const std::vector<array>& primals,
1608 const std::vector<array>& cotangents,
1609 const std::vector<int>& argnums,
1610 const std::vector<array>& outputs) override;
1611
1612 std::vector<std::vector<int>> output_shapes(
1613 const std::vector<array>& inputs) override;
1614
// Writes the enumerator name of reduce_type_ ("And", "Or", "Sum", "Prod",
// "Min" or "Max") to `os`. There is no default case, so a value outside the
// enum writes nothing.
1615 void print(std::ostream& os) override {
1616 switch (reduce_type_) {
1617 case And:
1618 os << "And";
1619 break;
1620 case Or:
1621 os << "Or";
1622 break;
1623 case Sum:
1624 os << "Sum";
1625 break;
1626 case Prod:
1627 os << "Prod";
1628 break;
1629 case Min:
1630 os << "Min";
1631 break;
1632 case Max:
1633 os << "Max";
1634 break;
1635 }
1636 }
1637 bool is_equivalent(const Primitive& other) const override;
1638
1639 private:
1640 ReduceType reduce_type_;
1641 std::vector<int> axes_;
1642
1643 void eval(const std::vector<array>& inputs, array& out);
1644};
1645
// Round primitive, derived from UnaryPrimitive. The constructor forwards
// the stream to the base class. Declares CPU/GPU evaluation overrides
// (inputs in, one `out` array written) and a private eval() helper.
// NOTE(review): presumably element-wise rounding; the rounding mode is not
// visible in this header -- confirm against the implementation files.
// NOTE(review): listing lines 1653-1657 are absent from this extract.
1646class Round : public UnaryPrimitive {
1647 public:
1648 explicit Round(Stream stream) : UnaryPrimitive(stream) {}
1649
1650 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1651 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1652
1658
1659 private:
1660 void eval(const std::vector<array>& inputs, array& out);
1661};
1662
// Scan primitive, derived from UnaryPrimitive. The constructor stores the
// reduce type, the axis, and the `reverse` and `inclusive` flags. Declares
// CPU/GPU evaluation overrides, is_equivalent(), and a private eval()
// helper; print() is defined inline.
// NOTE(review): presumably a cumulative (prefix) reduction along axis_, as
// the "Cum" prefix in print() suggests -- confirm against the implementation.
// NOTE(review): listing line 1665 is absent from this extract; it presumably
// declares the nested ReduceType used below (print() references Sum, Prod,
// Min and Max). Listing lines 1682-1683 are absent as well.
1663class Scan : public UnaryPrimitive {
1664 public:
1666
1667 explicit Scan(
1668 Stream stream,
1669 ReduceType reduce_type,
1670 int axis,
1671 bool reverse,
1672 bool inclusive)
1673 : UnaryPrimitive(stream),
1674 reduce_type_(reduce_type),
1675 axis_(axis),
1676 reverse_(reverse),
1677 inclusive_(inclusive) {}
1678
1679 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1680 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1681
1684
// Writes "Cum" followed by the reduce type name, e.g. "CumSum". With no
// default case, an unhandled value prints just "Cum".
1685 void print(std::ostream& os) override {
1686 os << "Cum";
1687 switch (reduce_type_) {
1688 case Sum:
1689 os << "Sum";
1690 break;
1691 case Prod:
1692 os << "Prod";
1693 break;
1694 case Min:
1695 os << "Min";
1696 break;
1697 case Max:
1698 os << "Max";
1699 break;
1700 }
1701 }
1702 bool is_equivalent(const Primitive& other) const override;
1703
1704 private:
1705 ReduceType reduce_type_;
1706 int axis_;
1707 bool reverse_;
1708 bool inclusive_;
1709
1710 void eval(const std::vector<array>& inputs, array& out);
1711};
1712
// Scatter primitive, derived from UnaryPrimitive. The constructor stores the
// reduce type and copies the `axes` vector. Declares CPU/GPU evaluation
// overrides, is_equivalent(), and a private eval() helper; print() is
// defined inline.
// NOTE(review): listing line 1715 is absent from this extract; it presumably
// declares the nested ReduceType used below (print() references Sum, Prod,
// Min, Max and None). Listing lines 1726-1727 are absent as well.
1713class Scatter : public UnaryPrimitive {
1714 public:
1716
1717 explicit Scatter(
1718 Stream stream,
1719 ReduceType reduce_type,
1720 const std::vector<int>& axes)
1721 : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1722
1723 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1724 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1725
1728
// Writes "Scatter", then a space and the reduce type name (e.g.
// "Scatter Sum"). For None nothing is appended, so the output is "Scatter".
1729 void print(std::ostream& os) override {
1730 os << "Scatter";
1731 switch (reduce_type_) {
1732 case Sum:
1733 os << " Sum";
1734 break;
1735 case Prod:
1736 os << " Prod";
1737 break;
1738 case Min:
1739 os << " Min";
1740 break;
1741 case Max:
1742 os << " Max";
1743 break;
1744 case None:
1745 break;
1746 }
1747 }
1748 bool is_equivalent(const Primitive& other) const override;
1749
1750 private:
1751 void eval(const std::vector<array>& inputs, array& out);
1752 ReduceType reduce_type_;
1753 std::vector<int> axes_;
1754};
1755
// Sigmoid primitive, derived from UnaryPrimitive. The constructor forwards
// the stream to the base class. Declares CPU/GPU evaluation overrides
// (inputs in, one `out` array written) and a private eval() helper.
// NOTE(review): presumably the element-wise logistic sigmoid -- confirm
// against the implementation files.
// NOTE(review): listing lines 1763-1767 are absent from this extract.
1756class Sigmoid : public UnaryPrimitive {
1757 public:
1758 explicit Sigmoid(Stream stream) : UnaryPrimitive(stream) {}
1759
1760 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1761 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1762
1768
1769 private:
1770 void eval(const std::vector<array>& inputs, array& out);
1771};
1772
// Sign primitive, derived from UnaryPrimitive. The constructor forwards the
// stream to the base class. Declares CPU/GPU evaluation overrides (inputs
// in, one `out` array written) and a private eval() helper.
// NOTE(review): presumably the element-wise sign function -- confirm
// against the implementation files.
// NOTE(review): listing lines 1780-1784 are absent from this extract.
1773class Sign : public UnaryPrimitive {
1774 public:
1775 explicit Sign(Stream stream) : UnaryPrimitive(stream) {}
1776
1777 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1778 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1779
1785
1786 private:
1787 void eval(const std::vector<array>& inputs, array& out);
1788};
1789
// Sin primitive, derived from UnaryPrimitive. The constructor forwards the
// stream to the base class. Declares CPU/GPU evaluation overrides (inputs
// in, one `out` array written) and a private eval() helper.
// NOTE(review): presumably the element-wise sine -- confirm against the
// implementation files.
// NOTE(review): listing lines 1797-1801 are absent from this extract.
1790class Sin : public UnaryPrimitive {
1791 public:
1792 explicit Sin(Stream stream) : UnaryPrimitive(stream) {}
1793
1794 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1795 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1796
1802
1803 private:
1804 void eval(const std::vector<array>& inputs, array& out);
1805};
1806
// Sinh primitive, derived from UnaryPrimitive. The constructor forwards the
// stream to the base class. Declares CPU/GPU evaluation overrides (inputs
// in, one `out` array written) and a private eval() helper.
// NOTE(review): presumably the element-wise hyperbolic sine -- confirm
// against the implementation files.
// NOTE(review): listing lines 1814-1818 are absent from this extract.
1807class Sinh : public UnaryPrimitive {
1808 public:
1809 explicit Sinh(Stream stream) : UnaryPrimitive(stream) {}
1810
1811 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1812 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1813
1819
1820 private:
1821 void eval(const std::vector<array>& inputs, array& out);
1822};
1823
// Slice primitive, derived from UnaryPrimitive. The constructor copies three
// integer vectors into members: start indices, end indices and strides.
// Declares CPU/GPU evaluation overrides, an is_equivalent() override, and a
// private eval() helper.
// NOTE(review): presumably one entry per input dimension, with the end
// index exclusive -- neither is visible in this header; confirm.
// NOTE(review): listing lines 1839-1841 are absent from this extract.
1824class Slice : public UnaryPrimitive {
1825 public:
1826 explicit Slice(
1827 Stream stream,
1828 const std::vector<int>& start_indices,
1829 const std::vector<int>& end_indices,
1830 const std::vector<int>& strides)
1831 : UnaryPrimitive(stream),
1832 start_indices_(start_indices),
1833 end_indices_(end_indices),
1834 strides_(strides) {}
1835
1836 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1837 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1838
1842 bool is_equivalent(const Primitive& other) const override;
1843
1844 private:
1845 std::vector<int> start_indices_;
1846 std::vector<int> end_indices_;
1847 std::vector<int> strides_;
1848
1849 void eval(const std::vector<array>& inputs, array& out);
1850};
1851
// SliceUpdate primitive, derived from UnaryPrimitive. The constructor copies
// three integer vectors into members: start indices, end indices and
// strides (the same parameters as Slice). Declares CPU/GPU evaluation
// overrides, an is_equivalent() override, a private eval() helper, and a
// private prepare_slice() helper.
// NOTE(review): the class head on listing line 1852 is reconstructed from
// the constructor name and the UnaryPrimitive(stream) base initializer;
// verify against the upstream header.
// NOTE(review): listing lines 1867-1869 are absent from this extract.
1852class SliceUpdate : public UnaryPrimitive {
1853 public:
1854 explicit SliceUpdate(
1855 Stream stream,
1856 const std::vector<int>& start_indices,
1857 const std::vector<int>& end_indices,
1858 const std::vector<int>& strides)
1859 : UnaryPrimitive(stream),
1860 start_indices_(start_indices),
1861 end_indices_(end_indices),
1862 strides_(strides) {}
1863
1864 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1865 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1866
1870 bool is_equivalent(const Primitive& other) const override;
1871
1872 private:
1873 std::vector<int> start_indices_;
1874 std::vector<int> end_indices_;
1875 std::vector<int> strides_;
1876
1877 void eval(const std::vector<array>& inputs, array& out);
1878
// Returns an (int64_t, vector<int64_t>) tuple computed from `in`.
// NOTE(review): presumably a data offset and the strides of the sliced
// view -- confirm against the implementation.
1879 std::tuple<int64_t, std::vector<int64_t>> prepare_slice(const array& in);
1880};
1881
// Softmax primitive, derived from UnaryPrimitive. The constructor stores the
// `precise` flag. Declares CPU/GPU evaluation overrides, an is_equivalent()
// override, and a private eval() helper.
// NOTE(review): `precise` presumably selects a higher-precision
// accumulation path; its effect is not visible in this header -- confirm.
// NOTE(review): listing lines 1890-1893 are absent from this extract.
1882class Softmax : public UnaryPrimitive {
1883 public:
1884 explicit Softmax(Stream stream, bool precise)
1885 : UnaryPrimitive(stream), precise_(precise) {}
1886
1887 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1888 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1889
1894
1895 bool is_equivalent(const Primitive& other) const override;
1896
1897 private:
1898 void eval(const std::vector<array>& inputs, array& out);
1899 bool precise_;
1900};
1901
// Sort primitive, derived from UnaryPrimitive. The constructor stores the
// `axis` argument. Declares CPU/GPU evaluation overrides, an
// is_equivalent() override, and a private eval() helper.
// NOTE(review): presumably sorts the input along axis_; ordering and
// stability are not visible in this header -- confirm.
// NOTE(review): listing lines 1910-1913 are absent from this extract.
1902class Sort : public UnaryPrimitive {
1903 public:
1904 explicit Sort(Stream stream, int axis)
1905 : UnaryPrimitive(stream), axis_(axis) {}
1906
1907 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1908 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1909
1914 bool is_equivalent(const Primitive& other) const override;
1915
1916 private:
1917 int axis_;
1918
1919 void eval(const std::vector<array>& inputs, array& out);
1920};
1921
// Split primitive. Unlike most primitives in this section it derives
// directly from Primitive, and its evaluation overrides write to a vector of
// output arrays rather than a single one. The constructor copies the
// `indices` vector and stores `axis`.
// NOTE(review): presumably `indices` are the split points along axis_ --
// confirm against the implementation.
// NOTE(review): listing lines 1932-1934 are absent from this extract.
1922class Split : public Primitive {
1923 public:
1924 explicit Split(Stream stream, const std::vector<int>& indices, int axis)
1925 : Primitive(stream), indices_(indices), axis_(axis) {}
1926
1927 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
1928 override;
1929 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
1930 override;
1931
1935 bool is_equivalent(const Primitive& other) const override;
1936
1937 private:
1938 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
1939
1940 std::vector<int> indices_;
1941 int axis_;
1942};
1943
// Square primitive, derived from UnaryPrimitive. The constructor forwards
// the stream to the base class. Declares CPU/GPU evaluation overrides
// (inputs in, one `out` array written) and a private eval() helper.
// NOTE(review): presumably the element-wise square -- confirm against the
// implementation files.
// NOTE(review): listing lines 1951-1955 are absent from this extract.
1944class Square : public UnaryPrimitive {
1945 public:
1946 explicit Square(Stream stream) : UnaryPrimitive(stream) {}
1947
1948 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1949 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1950
1956
1957 private:
1958 void eval(const std::vector<array>& inputs, array& out);
1959};
1960
// Sqrt primitive, derived from UnaryPrimitive. The constructor stores a
// `recip` flag that defaults to false. Declares CPU/GPU evaluation
// overrides, an is_equivalent() override, and a private eval() helper;
// print() is defined inline.
// NOTE(review): print() reports "Rsqrt" when recip_ is set, so the flag
// presumably selects the reciprocal square root -- confirm against the
// implementation.
// NOTE(review): listing lines 1969-1971 are absent from this extract.
1961class Sqrt : public UnaryPrimitive {
1962 public:
1963 explicit Sqrt(Stream stream, bool recip = false)
1964 : UnaryPrimitive(stream), recip_(recip) {}
1965
1966 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1967 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1968
1972 bool is_equivalent(const Primitive& other) const override;
1973
// Writes "Rsqrt" when recip_ is true, otherwise "Sqrt".
1974 void print(std::ostream& os) override {
1975 if (recip_) {
1976 os << "Rsqrt";
1977 } else {
1978 os << "Sqrt";
1979 }
1980 }
1981
1982 private:
1983 void eval(const std::vector<array>& inputs, array& out);
1984 bool recip_;
1985};
1986
// StopGradient primitive, derived from UnaryPrimitive. The constructor
// forwards the stream to the base class. Declares CPU/GPU evaluation
// overrides (inputs in, one `out` array written) and a private eval()
// helper.
// NOTE(review): the class head on listing line 1987 is reconstructed from
// the constructor name and the UnaryPrimitive(stream) base initializer;
// verify against the upstream header.
// NOTE(review): presumably passes its input through while blocking gradient
// propagation -- not visible in this header; confirm.
// NOTE(review): listing lines 1994-1997 are absent from this extract.
1987class StopGradient : public UnaryPrimitive {
1988 public:
1989 explicit StopGradient(Stream stream) : UnaryPrimitive(stream) {}
1990
1991 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1992 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1993
1998
1999 private:
2000 void eval(const std::vector<array>& inputs, array& out);
2001};
2002
// Subtract primitive, derived from UnaryPrimitive. The constructor forwards
// the stream to the base class. Declares CPU/GPU evaluation overrides
// (inputs in, one `out` array written) and a private eval() helper.
// NOTE(review): presumably element-wise subtraction of the inputs --
// confirm against the implementation files.
// NOTE(review): listing lines 2010-2014 are absent from this extract.
2003class Subtract : public UnaryPrimitive {
2004 public:
2005 explicit Subtract(Stream stream) : UnaryPrimitive(stream) {}
2006
2007 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2008 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2009
2015
2016 private:
2017 void eval(const std::vector<array>& inputs, array& out);
2018};
2019
// Tan primitive, derived from UnaryPrimitive. The constructor forwards the
// stream to the base class. Declares CPU/GPU evaluation overrides (inputs
// in, one `out` array written) and a private eval() helper.
// NOTE(review): presumably the element-wise tangent -- confirm against the
// implementation files.
// NOTE(review): listing lines 2027-2031 are absent from this extract.
2020class Tan : public UnaryPrimitive {
2021 public:
2022 explicit Tan(Stream stream) : UnaryPrimitive(stream) {}
2023
2024 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2025 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2026
2032
2033 private:
2034 void eval(const std::vector<array>& inputs, array& out);
2035};
2036
// Tanh primitive, derived from UnaryPrimitive. The constructor forwards the
// stream to the base class. Declares CPU/GPU evaluation overrides (inputs
// in, one `out` array written) and a private eval() helper.
// NOTE(review): presumably the element-wise hyperbolic tangent -- confirm
// against the implementation files.
// NOTE(review): listing lines 2044-2048 are absent from this extract.
2037class Tanh : public UnaryPrimitive {
2038 public:
2039 explicit Tanh(Stream stream) : UnaryPrimitive(stream) {}
2040
2041 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2042 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2043
2049
2050 private:
2051 void eval(const std::vector<array>& inputs, array& out);
2052};
2053
// Uniform primitive, derived from UnaryPrimitive. The constructor forwards
// the stream to the base class. Declares CPU/GPU evaluation overrides
// (inputs in, one `out` array written) and a private eval() helper.
// NOTE(review): presumably produces uniformly distributed values; what the
// inputs carry (bounds, key, ...) is not visible in this header -- confirm.
// NOTE(review): listing lines 2061-2063 are absent from this extract.
2054class Uniform : public UnaryPrimitive {
2055 public:
2056 explicit Uniform(Stream stream) : UnaryPrimitive(stream) {}
2057
2058 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2059 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2060
2064
2065 private:
2066 void eval(const std::vector<array>& inputs, array& out);
2067};
2068
// View primitive, derived from UnaryPrimitive. The constructor stores the
// target `dtype`. Declares CPU/GPU evaluation overrides plus print() and
// is_equivalent() overrides that are defined out of line. Unlike its
// neighbours, no private eval() helper is declared in the visible lines.
// NOTE(review): presumably reinterprets the input buffer as dtype_ without
// converting values -- not visible in this header; confirm.
// NOTE(review): listing line 2077 is absent from this extract.
2069class View : public UnaryPrimitive {
2070 public:
2071 explicit View(Stream stream, Dtype dtype)
2072 : UnaryPrimitive(stream), dtype_(dtype) {}
2073
2074 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2075 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2076
2078 void print(std::ostream& os) override;
2079 bool is_equivalent(const Primitive& other) const override;
2080
2081 private:
2082 Dtype dtype_;
2083};
2084
// Transpose primitive, derived from UnaryPrimitive. The constructor copies
// the `axes` vector into axes_. Declares CPU/GPU evaluation overrides, an
// is_equivalent() override, and a private eval() helper.
// NOTE(review): the class head on listing line 2085 is reconstructed from
// the constructor name and the UnaryPrimitive(stream) base initializer;
// verify against the upstream header.
// NOTE(review): presumably axes_ is the permutation applied to the input
// dimensions -- not visible in this header; confirm.
// NOTE(review): listing lines 2093-2095 are absent from this extract.
2085class Transpose : public UnaryPrimitive {
2086 public:
2087 explicit Transpose(Stream stream, const std::vector<int>& axes)
2088 : UnaryPrimitive(stream), axes_(axes) {}
2089
2090 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2091 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2092
2096 bool is_equivalent(const Primitive& other) const override;
2097
2098 private:
2099 std::vector<int> axes_;
2100
2101 void eval(const std::vector<array>& inputs, array& out);
2102};
2103
2104/* QR Factorization primitive. */
// Derives directly from Primitive; its evaluation overrides write to a
// vector of output arrays rather than a single one. The constructor forwards
// the stream to the base class.
// NOTE(review): presumably the outputs are the Q and R factors, in that
// order -- not visible in this header; confirm.
// NOTE(review): listing line 2114 is absent from this extract.
2105class QRF : public Primitive {
2106 public:
2107 explicit QRF(Stream stream) : Primitive(stream) {}
2108
2109 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2110 override;
2111 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2112 override;
2113
2115
2116 private:
2117 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
2118};
2119
2120/* SVD primitive. */
// Derives directly from Primitive; its evaluation overrides write to a
// vector of output arrays rather than a single one. The constructor forwards
// the stream to the base class.
// NOTE(review): presumably the outputs are the singular value decomposition
// factors; their number and order are not visible in this header -- confirm.
// NOTE(review): listing lines 2130-2131 are absent from this extract.
2121class SVD : public Primitive {
2122 public:
2123 explicit SVD(Stream stream) : Primitive(stream) {}
2124
2125 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2126 override;
2127 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2128 override;
2129
2132
2133 private:
2134 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
2135};
2136
2137/* Matrix inversion primitive. */
// Derives from UnaryPrimitive. The constructor stores two flags, `tri` and
// `upper`. Declares CPU/GPU evaluation overrides and a private eval()
// helper; here the output parameter is named `output` rather than `out`.
// No is_equivalent() override appears in the visible lines even though the
// class carries state.
// NOTE(review): presumably `tri` marks the input as triangular and `upper`
// selects which triangle -- not visible in this header; confirm.
// NOTE(review): listing lines 2146-2147 are absent from this extract.
2138class Inverse : public UnaryPrimitive {
2139 public:
2140 explicit Inverse(Stream stream, bool tri, bool upper)
2141 : UnaryPrimitive(stream), tri_(tri), upper_(upper) {}
2142
2143 void eval_cpu(const std::vector<array>& inputs, array& output) override;
2144 void eval_gpu(const std::vector<array>& inputs, array& output) override;
2145
2148
2149 private:
2150 void eval(const std::vector<array>& inputs, array& output);
2151 bool tri_;
2152 bool upper_;
2153};
2154
// Cholesky primitive, derived from UnaryPrimitive. The constructor stores
// the `upper` flag. Declares CPU/GPU evaluation overrides and a private
// eval() helper. No is_equivalent() override appears in the visible lines
// even though the class carries state.
// NOTE(review): presumably `upper` selects the upper- rather than
// lower-triangular factor -- not visible in this header; confirm.
// NOTE(review): listing lines 2163-2164 are absent from this extract.
2155class Cholesky : public UnaryPrimitive {
2156 public:
2157 explicit Cholesky(Stream stream, bool upper)
2158 : UnaryPrimitive(stream), upper_(upper) {}
2159
2160 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2161 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2162
2165
2166 private:
2167 void eval(const std::vector<array>& inputs, array& output);
2168 bool upper_;
2169};
2170
2171} // namespace mlx::core
Definition primitives.h:155
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Abs(Stream stream)
Definition primitives.h:157
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:164
std::vector< std::vector< int > > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:166
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:165
Definition primitives.h:172
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Add(Stream stream)
Definition primitives.h:174
Definition primitives.h:189
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
AddMM(Stream stream, float alpha, float beta)
Definition primitives.h:191
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition primitives.h:213
Arange(Stream stream, double start, double stop, double step)
Definition primitives.h:215
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:232
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCos(Stream stream)
Definition primitives.h:234
Definition primitives.h:249
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCosh(Stream stream)
Definition primitives.h:251
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:266
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcSin(Stream stream)
Definition primitives.h:268
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:283
ArcSinh(Stream stream)
Definition primitives.h:285
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:317
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTan2(Stream stream)
Definition primitives.h:319
Definition primitives.h:300
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcTan(Stream stream)
Definition primitives.h:302
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:334
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTanh(Stream stream)
Definition primitives.h:336
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:351
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArgPartition(Stream stream, int kth, int axis)
Definition primitives.h:353
Definition primitives.h:372
ReduceType
Definition primitives.h:374
@ ArgMin
Definition primitives.h:375
@ ArgMax
Definition primitives.h:376
ArgReduce(Stream stream, ReduceType reduce_type, int axis)
Definition primitives.h:379
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:399
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArgSort(Stream stream, int axis)
Definition primitives.h:401
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:438
AsStrided(Stream stream, std::vector< int > shape, std::vector< size_t > strides, size_t offset)
Definition primitives.h:440
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:418
void eval_gpu(const std::vector< array > &inputs, array &out) override
AsType(Stream stream, Dtype dtype)
Definition primitives.h:420
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:465
BitwiseBinary(Stream stream, Op op)
Definition primitives.h:469
void eval_cpu(const std::vector< array > &inputs, array &out) override
Op
Definition primitives.h:467
@ And
Definition primitives.h:467
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:485
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
BlockMaskedMM(Stream stream, int block_size)
Definition primitives.h:487
Definition primitives.h:528
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Broadcast(Stream stream, const std::vector< int > &shape)
Definition primitives.h:530
Definition primitives.h:547
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Ceil(Stream stream)
Definition primitives.h:549
Definition primitives.h:2155
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cholesky(Stream stream, bool upper)
Definition primitives.h:2157
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:564
Compiled(Stream stream, std::vector< array > inputs, std::vector< array > outputs, std::vector< array > tape, std::unordered_set< uintptr_t > constant_ids)
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output array.
Definition primitives.h:607
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Concatenate(Stream stream, int axis)
Definition primitives.h:609
Definition primitives.h:626
Conjugate(Stream stream)
Definition primitives.h:628
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:642
void eval_gpu(const std::vector< array > &inputs, array &out) override
Convolution(Stream stream, const std::vector< int > &kernel_strides, const std::vector< int > &padding, const std::vector< int > &kernel_dilation, const std::vector< int > &input_dilation, const int groups=1, const bool flip=false)
Definition primitives.h:644
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition primitives.h:683
void eval_gpu(const std::vector< array > &inputs, array &out) override
Copy(Stream stream)
Definition primitives.h:685
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:700
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cos(Stream stream)
Definition primitives.h:702
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:717
void eval_gpu(const std::vector< array > &inputs, array &out) override
Cosh(Stream stream)
Definition primitives.h:719
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:734
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
CustomTransforms(Stream stream, int num_outputs, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> vjp, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< int > &)> jvp, std::function< std::pair< std::vector< array >, std::vector< int > >(const std::vector< array > &, const std::vector< int > &)> vmap)
Definition primitives.h:736
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output array.
Definition primitives.h:786
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output array.
Depends(Stream stream)
Definition primitives.h:788
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:824
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
DivMod(Stream stream)
Definition primitives.h:826
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output array.
Definition primitives.h:807
Divide(Stream stream)
Definition primitives.h:809
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:880
Equal(Stream stream, bool equal_nan=false)
Definition primitives.h:882
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:906
Erf(Stream stream)
Definition primitives.h:908
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:923
void eval_gpu(const std::vector< array > &inputs, array &out) override
ErfInv(Stream stream)
Definition primitives.h:925
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:940
Exp(Stream stream)
Definition primitives.h:942
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:957
Expm1(Stream stream)
Definition primitives.h:959
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:973
FFT(Stream stream, const std::vector< size_t > &axes, bool inverse, bool real)
Definition primitives.h:975
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:999
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Floor(Stream stream)
Definition primitives.h:1001
Definition primitives.h:1016
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Full(Stream stream)
Definition primitives.h:1018
Definition primitives.h:1032
Gather(Stream stream, const std::vector< int > &axes, const std::vector< int > &slice_sizes)
Definition primitives.h:1034
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:508
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
GatherMM(Stream stream)
Definition primitives.h:510
Definition primitives.h:1521
GatherQMM(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1523
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1071
void eval_cpu(const std::vector< array > &inputs, array &out) override
GreaterEqual(Stream stream)
Definition primitives.h:1073
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1054
Greater(Stream stream)
Definition primitives.h:1056
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1088
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Hadamard(Stream stream, float scale)
Definition primitives.h:1090
Definition primitives.h:2138
void eval_gpu(const std::vector< array > &inputs, array &output) override
Inverse(Stream stream, bool tri, bool upper)
Definition primitives.h:2140
void eval_cpu(const std::vector< array > &inputs, array &output) override
Definition primitives.h:1126
LessEqual(Stream stream)
Definition primitives.h:1128
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1109
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Less(Stream stream)
Definition primitives.h:1111
Definition primitives.h:1143
void eval_gpu(const std::vector< array > &inputs, array &out) override
Load(Stream stream, std::shared_ptr< io::Reader > reader, size_t offset, bool swap_endianness=false)
Definition primitives.h:1145
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1209
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Log1p(Stream stream)
Definition primitives.h:1211
Definition primitives.h:1276
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogAddExp(Stream stream)
Definition primitives.h:1278
Definition primitives.h:1175
Base
Definition primitives.h:1177
Log(Stream stream, Base base)
Definition primitives.h:1179
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1242
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalAnd(Stream stream)
Definition primitives.h:1244
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1225
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalNot(Stream stream)
Definition primitives.h:1227
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1259
void eval_cpu(const std::vector< array > &inputs, array &out) override
LogicalOr(Stream stream)
Definition primitives.h:1261
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1293
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
Matmul(Stream stream)
Definition primitives.h:1295
Definition primitives.h:1311
Maximum(Stream stream)
Definition primitives.h:1313
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1328
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Minimum(Stream stream)
Definition primitives.h:1330
Definition primitives.h:1345
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Multiply(Stream stream)
Definition primitives.h:1347
Definition primitives.h:1362
void eval_gpu(const std::vector< array > &inputs, array &out) override
Negative(Stream stream)
Definition primitives.h:1364
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1379
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
NotEqual(Stream stream)
Definition primitives.h:1381
Definition primitives.h:1396
void eval_gpu(const std::vector< array > &inputs, array &out) override
NumberOfElements(Stream stream, std::vector< int > axes, bool inverted, Dtype dtype)
Definition primitives.h:1398
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1427
void eval_cpu(const std::vector< array > &inputs, array &out) override
Pad(Stream stream, const std::vector< int > &axes, const std::vector< int > &low_pad_size, const std::vector< int > &high_pad_size)
Definition primitives.h:1429
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1455
void eval_cpu(const std::vector< array > &inputs, array &out) override
Partition(Stream stream, int kth, int axis)
Definition primitives.h:1457
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1476
void eval_cpu(const std::vector< array > &inputs, array &out) override
Power(Stream stream)
Definition primitives.h:1478
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:48
virtual void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
The vector-Jacobian product.
virtual ~Primitive()=default
Primitive(const Primitive &other)=delete
Primitive(Primitive &&other)=delete
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
Primitive & operator=(Primitive &&other)=delete
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
Primitive & operator=(const Primitive &other)=delete
virtual std::vector< std::vector< int > > output_shapes(const std::vector< array > &inputs)
Get the output shapes of the primitive.
const Device & device()
The device the primitive will run on.
Definition primitives.h:53
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
The Jacobian-vector product.
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes)
The primitive must know how to vectorize itself across the given axes.
virtual void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
virtual void print(std::ostream &os)=0
Print the primitive.
Primitive(Stream stream)
Definition primitives.h:50
Definition primitives.h:2105
QRF(Stream stream)
Definition primitives.h:2107
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:1493
void eval_gpu(const std::vector< array > &inputs, array &out) override
QuantizedMatmul(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1495
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1545
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
RandomBits(Stream stream, const std::vector< int > &shape, int width)
Definition primitives.h:1547
Definition primitives.h:1591
Reduce(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1595
ReduceType
Definition primitives.h:1593
@ And
Definition primitives.h:1593
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:863
Remainder(Stream stream)
Definition primitives.h:865
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1564
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Reshape(Stream stream, const std::vector< int > &shape)
Definition primitives.h:1566
Definition primitives.h:1646
Round(Stream stream)
Definition primitives.h:1648
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2121
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
SVD(Stream stream)
Definition primitives.h:2123
Definition primitives.h:1663
void eval_cpu(const std::vector< array > &inputs, array &out) override
ReduceType
Definition primitives.h:1665
@ Max
Definition primitives.h:1665
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Scan(Stream stream, ReduceType reduce_type, int axis, bool reverse, bool inclusive)
Definition primitives.h:1667
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1713
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
ReduceType
Definition primitives.h:1715
@ Max
Definition primitives.h:1715
void eval_cpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1729
void eval_gpu(const std::vector< array > &inputs, array &out) override
Scatter(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1717
Definition primitives.h:846
void eval_gpu(const std::vector< array > &inputs, array &out) override
Select(Stream stream)
Definition primitives.h:848
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1756
Sigmoid(Stream stream)
Definition primitives.h:1758
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1773
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Sign(Stream stream)
Definition primitives.h:1775
Definition primitives.h:1790
Sin(Stream stream)
Definition primitives.h:1792
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1807
Sinh(Stream stream)
Definition primitives.h:1809
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1824
void eval_cpu(const std::vector< array > &inputs, array &out) override
Slice(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
Definition primitives.h:1826
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1852
SliceUpdate(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
Definition primitives.h:1854
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1882
void eval_gpu(const std::vector< array > &inputs, array &out) override
Softmax(Stream stream, bool precise)
Definition primitives.h:1884
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1902
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sort(Stream stream, int axis)
Definition primitives.h:1904
Definition primitives.h:1922
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Split(Stream stream, const std::vector< int > &indices, int axis)
Definition primitives.h:1924
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition primitives.h:1961
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sqrt(Stream stream, bool recip=false)
Definition primitives.h:1963
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1944
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Square(Stream stream)
Definition primitives.h:1946
Definition primitives.h:1987
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
StopGradient(Stream stream)
Definition primitives.h:1989
Definition primitives.h:2003
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Subtract(Stream stream)
Definition primitives.h:2005
Definition primitives.h:2020
Tan(Stream stream)
Definition primitives.h:2022
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2037
void eval_gpu(const std::vector< array > &inputs, array &out) override
Tanh(Stream stream)
Definition primitives.h:2039
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2085
Transpose(Stream stream, const std::vector< int > &axes)
Definition primitives.h:2087
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:127
UnaryPrimitive & operator=(const UnaryPrimitive &other)=delete
UnaryPrimitive(Stream stream)
An abstract base class for a primitive with a single output.
Definition primitives.h:132
virtual void eval_gpu(const std::vector< array > &inputs, array &output)=0
UnaryPrimitive(UnaryPrimitive &&other)=delete
virtual void eval_cpu(const std::vector< array > &inputs, array &output)=0
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:142
UnaryPrimitive(const UnaryPrimitive &other)=delete
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition primitives.h:137
UnaryPrimitive & operator=(UnaryPrimitive &&other)=delete
virtual ~UnaryPrimitive()=default
Definition primitives.h:2054
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Uniform(Stream stream)
Definition primitives.h:2056
Definition primitives.h:2069
void eval_cpu(const std::vector< array > &inputs, array &out) override
View(Stream stream, Dtype dtype)
Definition primitives.h:2071
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition array.h:20
Op op
Definition binary.h:129
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
array tri(int n, int m, int k, Dtype type, StreamOrDevice s={})
array transpose(const array &a, std::vector< int > axes, StreamOrDevice s={})
Permutes the dimensions according to the given axes.
Definition allocator.h:7
std::pair< std::vector< array >, std::vector< array > > jvp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &tangents)
Computes the output and Jacobian-vector product (JVP) of a function.
std::pair< std::vector< array >, std::vector< array > > vjp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &cotangents)
Computes the output and vector-Jacobian product (VJP) of a function.
void eval(std::vector< array > outputs)
std::function< array(const array &)> vmap(const std::function< array(const array &)> &fun, int in_axis=0, int out_axis=0)
Automatically vectorize a unary function over the requested axes.
#define DEFINE_DEFAULT_IS_EQUIVALENT()
Definition primitives.h:34
#define DEFINE_PRINT(PRIMITIVE)
Definition primitives.h:29
#define DEFINE_INPUT_OUTPUT_SHAPE()
Definition primitives.h:39
#define DEFINE_GRADS()
Definition primitives.h:17
#define DEFINE_VMAP()
Definition primitives.h:12
Definition ops.h:37
Definition binary_ops.h:270
Definition ops.h:185
Definition ops.h:163
Definition ops.h:29
Definition ops.h:78
Definition ops.h:141
Definition binary_ops.h:277
Definition ops.h:119
Definition device.h:7
Definition dtype.h:13
Definition stream.h:9
Device device
Definition stream.h:11