// MLX
// primitives.h — source listing taken from the generated API documentation
// (search-widget and navigation text from the rendered page omitted).
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <unordered_set>
6
7#include "mlx/array.h"
8#include "mlx/device.h"
9#include "mlx/io/load.h"
10#include "mlx/stream.h"
11
// Declares the vmap() override of Primitive. `override` already implies
// `virtual`, so the keyword is omitted (matching DEFINE_GRADS).
#define DEFINE_VMAP()                                                 \
  std::pair<std::vector<array>, std::vector<int>> vmap(               \
      const std::vector<array>& inputs, const std::vector<int>& axes) \
      override;
16
// Declares both derivative overrides of Primitive (jvp and vjp); the
// definitions live out of line in each primitive's implementation file.
#define DEFINE_GRADS()                           \
  std::vector<array> jvp(                        \
      const std::vector<array>& primals,         \
      const std::vector<array>& tangents,        \
      const std::vector<int>& argnums) override; \
                                                 \
  std::vector<array> vjp(                        \
      const std::vector<array>& primals,         \
      const std::vector<array>& cotangents,      \
      const std::vector<int>& argnums,           \
      const std::vector<array>& outputs) override;
28
// Defines a print() override that writes the primitive's name, i.e. the
// macro argument stringized (DEFINE_PRINT(Abs) prints "Abs").
#define DEFINE_PRINT(PRIMITIVE)           \
  void print(std::ostream& os) override { \
    os << #PRIMITIVE;                      \
  }
33
// Defines an is_equivalent() override that always returns true. It is meant
// for parameterless primitives, where two instances of the same class are
// interchangeable. The argument is never inspected, so it is left unnamed.
#define DEFINE_DEFAULT_IS_EQUIVALENT()                                \
  bool is_equivalent(const Primitive& /* other */) const override { \
    return true;                                                      \
  }
38
// Defines an output_shapes() override for single-output primitives whose
// output has the same shape as their first input.
#define DEFINE_INPUT_OUTPUT_SHAPE()                  \
  std::vector<std::vector<int>> output_shapes(       \
      const std::vector<array>& inputs) override {   \
    return {inputs[0].shape()};                      \
  }
44
45namespace mlx::core {
46
47// Abstract base class
48class Primitive {
49 public:
50 explicit Primitive(Stream stream) : stream_(stream) {}
51
53 const Device& device() {
54 return stream().device;
55 }
56
58 const Stream& stream() {
59 return stream_;
60 }
61
69 virtual void eval_cpu(
70 const std::vector<array>& inputs,
71 std::vector<array>& outputs) = 0;
72 virtual void eval_gpu(
73 const std::vector<array>& inputs,
74 std::vector<array>& outputs) = 0;
75
79 virtual std::vector<array> jvp(
80 const std::vector<array>& primals,
81 const std::vector<array>& tangents,
82 const std::vector<int>& argnums);
83
87 virtual std::vector<array> vjp(
88 const std::vector<array>& primals,
89 const std::vector<array>& cotangents,
90 const std::vector<int>& argnums,
91 const std::vector<array>& outputs);
92
99 virtual std::pair<std::vector<array>, std::vector<int>> vmap(
100 const std::vector<array>& inputs,
101 const std::vector<int>& axes);
102
104 virtual void print(std::ostream& os) = 0;
105
107 virtual bool is_equivalent(const Primitive& other) const {
108 return false;
109 }
110
113 virtual std::vector<std::vector<int>> output_shapes(
114 const std::vector<array>& inputs);
115
116 virtual ~Primitive() = default;
117 Primitive(const Primitive& other) = delete;
118 Primitive(Primitive&& other) = delete;
119 Primitive& operator=(const Primitive& other) = delete;
120 Primitive& operator=(Primitive&& other) = delete;
121
122 private:
123 // Every primitive stores the stream it should run in
124 Stream stream_;
125};
126
127class UnaryPrimitive : public Primitive {
131 public:
133
134 virtual void eval_cpu(const std::vector<array>& inputs, array& output) = 0;
135 virtual void eval_gpu(const std::vector<array>& inputs, array& output) = 0;
136
137 inline void eval_cpu(
138 const std::vector<array>& inputs,
139 std::vector<array>& outputs) override {
140 eval_cpu(inputs, outputs[0]);
141 }
142 inline void eval_gpu(
143 const std::vector<array>& inputs,
144 std::vector<array>& outputs) override {
145 eval_gpu(inputs, outputs[0]);
146 }
147
148 virtual ~UnaryPrimitive() = default;
149 UnaryPrimitive(const UnaryPrimitive& other) = delete;
151 UnaryPrimitive& operator=(const UnaryPrimitive& other) = delete;
153};
154
155class Abs : public UnaryPrimitive {
156 public:
158
159 void eval_cpu(const std::vector<array>& inputs, array& out) override;
160 void eval_gpu(const std::vector<array>& inputs, array& out) override;
161
167
168 private:
169 void eval(const std::vector<array>& inputs, array& out);
170};
171
172class Add : public UnaryPrimitive {
173 public:
175
176 void eval_cpu(const std::vector<array>& inputs, array& out) override;
177 void eval_gpu(const std::vector<array>& inputs, array& out) override;
178
184
185 private:
186 void eval(const std::vector<array>& inputs, array& out);
187};
188
189class AddMM : public UnaryPrimitive {
190 public:
191 explicit AddMM(Stream stream, float alpha, float beta)
192 : UnaryPrimitive(stream), alpha_(alpha), beta_(beta) {}
193
194 void eval_cpu(const std::vector<array>& inputs, array& out) override;
195 void eval_gpu(const std::vector<array>& inputs, array& out) override;
196
197 std::vector<array> vjp(
198 const std::vector<array>& primals,
199 const std::vector<array>& cotangents,
200 const std::vector<int>& argnums,
201 const std::vector<array>& outputs) override;
202
205
206 bool is_equivalent(const Primitive& other) const override;
207
208 private:
209 const float alpha_;
210 const float beta_;
211};
212
213class Arange : public UnaryPrimitive {
214 public:
215 explicit Arange(Stream stream, double start, double stop, double step)
216 : UnaryPrimitive(stream), start_(start), stop_(stop), step_(step) {}
217
218 void eval_cpu(const std::vector<array>& inputs, array& out) override;
219 void eval_gpu(const std::vector<array>& inputs, array& out) override;
220
222 bool is_equivalent(const Primitive& other) const override;
223
224 private:
225 double start_;
226 double stop_;
227 double step_;
228
229 void eval(const std::vector<array>& inputs, array& out);
230};
231
232class ArcCos : public UnaryPrimitive {
233 public:
235
236 void eval_cpu(const std::vector<array>& inputs, array& out) override;
237 void eval_gpu(const std::vector<array>& inputs, array& out) override;
238
244
245 private:
246 void eval(const std::vector<array>& inputs, array& out);
247};
248
249class ArcCosh : public UnaryPrimitive {
250 public:
252
253 void eval_cpu(const std::vector<array>& inputs, array& out) override;
254 void eval_gpu(const std::vector<array>& inputs, array& out) override;
255
261
262 private:
263 void eval(const std::vector<array>& inputs, array& out);
264};
265
266class ArcSin : public UnaryPrimitive {
267 public:
269
270 void eval_cpu(const std::vector<array>& inputs, array& out) override;
271 void eval_gpu(const std::vector<array>& inputs, array& out) override;
272
278
279 private:
280 void eval(const std::vector<array>& inputs, array& out);
281};
282
283class ArcSinh : public UnaryPrimitive {
284 public:
286
287 void eval_cpu(const std::vector<array>& inputs, array& out) override;
288 void eval_gpu(const std::vector<array>& inputs, array& out) override;
289
295
296 private:
297 void eval(const std::vector<array>& inputs, array& out);
298};
299
300class ArcTan : public UnaryPrimitive {
301 public:
303
304 void eval_cpu(const std::vector<array>& inputs, array& out) override;
305 void eval_gpu(const std::vector<array>& inputs, array& out) override;
306
312
313 private:
314 void eval(const std::vector<array>& inputs, array& out);
315};
316
317class ArcTan2 : public UnaryPrimitive {
318 public:
320
321 void eval_cpu(const std::vector<array>& inputs, array& out) override;
322 void eval_gpu(const std::vector<array>& inputs, array& out) override;
323
329
330 private:
331 void eval(const std::vector<array>& inputs, array& out);
332};
333
334class ArcTanh : public UnaryPrimitive {
335 public:
337
338 void eval_cpu(const std::vector<array>& inputs, array& out) override;
339 void eval_gpu(const std::vector<array>& inputs, array& out) override;
340
346
347 private:
348 void eval(const std::vector<array>& inputs, array& out);
349};
350
352 public:
353 explicit ArgPartition(Stream stream, int kth, int axis)
354 : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
355
356 void eval_cpu(const std::vector<array>& inputs, array& out) override;
357 void eval_gpu(const std::vector<array>& inputs, array& out) override;
358
363 bool is_equivalent(const Primitive& other) const override;
364
365 private:
366 int kth_;
367 int axis_;
368
369 void eval(const std::vector<array>& inputs, array& out);
370};
371
372class ArgReduce : public UnaryPrimitive {
373 public:
378
379 explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)
380 : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {}
381
382 void eval_cpu(const std::vector<array>& inputs, array& out) override;
383 void eval_gpu(const std::vector<array>& inputs, array& out) override;
384
388 bool is_equivalent(const Primitive& other) const override;
389 std::vector<std::vector<int>> output_shapes(
390 const std::vector<array>& inputs) override;
391
392 private:
393 ReduceType reduce_type_;
394 int axis_;
395
396 void eval(const std::vector<array>& inputs, array& out);
397};
398
399class ArgSort : public UnaryPrimitive {
400 public:
401 explicit ArgSort(Stream stream, int axis)
402 : UnaryPrimitive(stream), axis_(axis) {}
403
404 void eval_cpu(const std::vector<array>& inputs, array& out) override;
405 void eval_gpu(const std::vector<array>& inputs, array& out) override;
406
410 bool is_equivalent(const Primitive& other) const override;
411
412 private:
413 int axis_;
414
415 void eval(const std::vector<array>& inputs, array& out);
416};
417
418class AsType : public UnaryPrimitive {
419 public:
420 explicit AsType(Stream stream, Dtype dtype)
421 : UnaryPrimitive(stream), dtype_(dtype) {}
422
423 void eval_cpu(const std::vector<array>& inputs, array& out) override;
424 void eval_gpu(const std::vector<array>& inputs, array& out) override;
425
430 bool is_equivalent(const Primitive& other) const override;
431
432 private:
433 Dtype dtype_;
434
435 void eval(const std::vector<array>& inputs, array& out);
436};
437
438class AsStrided : public UnaryPrimitive {
439 public:
440 explicit AsStrided(
442 std::vector<int> shape,
443 std::vector<size_t> strides,
444 size_t offset)
446 shape_(std::move(shape)),
447 strides_(std::move(strides)),
448 offset_(offset) {}
449
450 void eval_cpu(const std::vector<array>& inputs, array& out) override;
451 void eval_gpu(const std::vector<array>& inputs, array& out) override;
452
455 bool is_equivalent(const Primitive& other) const override;
456
457 private:
458 std::vector<int> shape_;
459 std::vector<size_t> strides_;
460 size_t offset_;
461
462 void eval(const std::vector<array>& inputs, array& out);
463};
464
// NOTE(review): this class's opening `class ... {` line and its constructor
// appear to have been lost when this header was extracted from the rendered
// docs. The source numbering skips them, and the closing `};` below has no
// matching opening brace. The single-output eval_cpu/eval_gpu overrides imply
// a UnaryPrimitive subclass. Restore the missing lines from upstream.
466 public:
// The operations this primitive can represent; the selected one is op_.
467 enum Op { And, Or, Xor, LeftShift, RightShift };
468
471
472 void eval_cpu(const std::vector<array>& inputs, array& out) override;
473 void eval_gpu(const std::vector<array>& inputs, array& out) override;
474
// Both are declared here and defined out of line.
477 bool is_equivalent(const Primitive& other) const override;
478 void print(std::ostream& os) override;
480
481 private:
// Operation performed by this instance.
482 Op op_;
483};
484
486 public:
487 explicit BlockMaskedMM(Stream stream, int block_size)
488 : UnaryPrimitive(stream), block_size_(block_size) {}
489
490 void eval_cpu(const std::vector<array>& inputs, array& out) override;
491 void eval_gpu(const std::vector<array>& inputs, array& out) override;
492
493 std::vector<array> vjp(
494 const std::vector<array>& primals,
495 const std::vector<array>& cotangents,
496 const std::vector<int>& argnums,
497 const std::vector<array>& outputs) override;
498
500 bool is_equivalent(const Primitive& other) const override;
501
502 private:
503 int block_size_;
504
505 void eval(const std::vector<array>& inputs, array& out);
506};
507
508class GatherMM : public UnaryPrimitive {
509 public:
511
512 void eval_cpu(const std::vector<array>& inputs, array& out) override;
513 void eval_gpu(const std::vector<array>& inputs, array& out) override;
514
515 std::vector<array> vjp(
516 const std::vector<array>& primals,
517 const std::vector<array>& cotangents,
518 const std::vector<int>& argnums,
519 const std::vector<array>& outputs) override;
520
523
524 private:
525 void eval(const std::vector<array>& inputs, array& out);
526};
527
528class Broadcast : public UnaryPrimitive {
529 public:
530 explicit Broadcast(Stream stream, const std::vector<int>& shape)
531 : UnaryPrimitive(stream), shape_(shape) {}
532
533 void eval_cpu(const std::vector<array>& inputs, array& out) override;
534 void eval_gpu(const std::vector<array>& inputs, array& out) override;
535
539 bool is_equivalent(const Primitive& other) const override;
540
541 private:
542 std::vector<int> shape_;
543
544 void eval(const std::vector<array>& inputs, array& out);
545};
546
547class Ceil : public UnaryPrimitive {
548 public:
550
551 void eval_cpu(const std::vector<array>& inputs, array& out) override;
552 void eval_gpu(const std::vector<array>& inputs, array& out) override;
553
559
560 private:
561 void eval(const std::vector<array>& inputs, array& out);
562};
563
564class Compiled : public Primitive {
565 public:
566 /*
567 * The inputs, outputs and tape are either tracers or constants.
568 * - The tape should not contain the inputs, but it should contain the
569 * outputs.
570 * - The tape should also have only one array per primitive for multi-output
571 * primitives.
572 * - The constant_ids contains ids of arrays in the input list that are safe
573 * to treat as scalar constants.
574 */
575 explicit Compiled(
577 std::vector<array> inputs,
578 std::vector<array> outputs,
579 std::vector<array> tape,
580 std::unordered_set<uintptr_t> constant_ids);
581
582 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
583 override;
584 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
585 override;
586
589 std::vector<std::vector<int>> output_shapes(
590 const std::vector<array>& inputs) override;
591 void print(std::ostream& os) override;
592 bool is_equivalent(const Primitive& other) const override;
593
594 std::string lib_name() const {
595 return kernel_lib_;
596 }
597
598 private:
599 const std::vector<array> inputs_;
600 const std::vector<array> outputs_;
601 const std::vector<array> tape_;
602 const std::unordered_set<uintptr_t> constant_ids_;
603
604 std::string kernel_lib_;
605};
606
608 public:
609 explicit Concatenate(Stream stream, int axis)
610 : UnaryPrimitive(stream), axis_(axis) {}
611
612 void eval_cpu(const std::vector<array>& inputs, array& out) override;
613 void eval_gpu(const std::vector<array>& inputs, array& out) override;
614
618 bool is_equivalent(const Primitive& other) const override;
619
620 private:
621 int axis_;
622
623 void eval(const std::vector<array>& inputs, array& out);
624};
625
626class Conjugate : public UnaryPrimitive {
627 public:
628 explicit Conjugate(Stream stream) : UnaryPrimitive(stream) {}
629
630 void eval_cpu(const std::vector<array>& inputs, array& out) override;
631 void eval_gpu(const std::vector<array>& inputs, array& out) override;
632
637
638 private:
639 void eval(const std::vector<array>& inputs, array& out);
640};
641
643 public:
644 explicit Convolution(
645 Stream stream,
646 const std::vector<int>& kernel_strides,
647 const std::vector<int>& padding,
648 const std::vector<int>& kernel_dilation,
649 const std::vector<int>& input_dilation,
650 const int groups = 1,
651 const bool flip = false)
652 : UnaryPrimitive(stream),
653 padding_(padding),
654 kernel_strides_(kernel_strides),
655 kernel_dilation_(kernel_dilation),
656 input_dilation_(input_dilation),
657 groups_(groups),
658 flip_(flip) {}
659
660 void eval_cpu(const std::vector<array>& inputs, array& out) override;
661 void eval_gpu(const std::vector<array>& inputs, array& out) override;
662
663 std::vector<array> vjp(
664 const std::vector<array>& primals,
665 const std::vector<array>& cotangents,
666 const std::vector<int>& argnums,
667 const std::vector<array>& outputs) override;
668
670 bool is_equivalent(const Primitive& other) const override;
671
672 private:
673 std::vector<int> padding_;
674 std::vector<int> kernel_strides_;
675 std::vector<int> kernel_dilation_;
676 std::vector<int> input_dilation_;
677 int groups_;
678 bool flip_;
679
680 void eval(const std::vector<array>& inputs, array& out);
681};
682
683class Copy : public UnaryPrimitive {
684 public:
685 explicit Copy(Stream stream) : UnaryPrimitive(stream) {}
686
687 void eval_cpu(const std::vector<array>& inputs, array& out) override;
688 void eval_gpu(const std::vector<array>& inputs, array& out) override;
689
695
696 private:
697 void eval(const std::vector<array>& inputs, array& out);
698};
699
700class Cos : public UnaryPrimitive {
701 public:
702 explicit Cos(Stream stream) : UnaryPrimitive(stream) {}
703
704 void eval_cpu(const std::vector<array>& inputs, array& out) override;
705 void eval_gpu(const std::vector<array>& inputs, array& out) override;
706
712
713 private:
714 void eval(const std::vector<array>& inputs, array& out);
715};
716
717class Cosh : public UnaryPrimitive {
718 public:
719 explicit Cosh(Stream stream) : UnaryPrimitive(stream) {}
720
721 void eval_cpu(const std::vector<array>& inputs, array& out) override;
722 void eval_gpu(const std::vector<array>& inputs, array& out) override;
723
729
730 private:
731 void eval(const std::vector<array>& inputs, array& out);
732};
733
// NOTE(review): this class's opening `class ... {` line, and the line naming
// its constructor, appear to have been lost when this header was extracted.
// The source numbering skips them, and the closing `};` below has no matching
// opening brace. The constructor initializes `Primitive(stream)`, so this is a
// direct Primitive subclass using the multi-output eval interface. Restore the
// missing lines from upstream.
735 public:
// Constructor parameters: the stream, the number of outputs, and three
// user-supplied callbacks (vjp, jvp, vmap). The callbacks are moved into
// the members below.
737 Stream stream,
738 int num_outputs,
739 std::function<std::vector<array>(
740 const std::vector<array>&,
741 const std::vector<array>&,
742 const std::vector<array>&)> vjp,
743 std::function<std::vector<array>(
744 const std::vector<array>&,
745 const std::vector<array>&,
746 const std::vector<int>&)> jvp,
747 std::function<std::pair<std::vector<array>, std::vector<int>>(
748 const std::vector<array>&,
749 const std::vector<int>&)> vmap)
750 : Primitive(stream),
751 num_outputs_(num_outputs),
752 vjp_fun_(std::move(vjp)),
753 jvp_fun_(std::move(jvp)),
754 vmap_fun_(std::move(vmap)) {}
755
756 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
757 override;
758 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
759 override;
760
// NOTE(review): further declarations were lost in extraction here. No print()
// override is visible, but Primitive requires one; restore from upstream.
764
765 private:
766 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
767
// Output count supplied at construction.
768 int num_outputs_;
769
// Stored callbacks, moved in from the constructor arguments of the same
// names. Their signatures match the constructor parameters.
770 std::function<std::vector<array>(
771 const std::vector<array>&,
772 const std::vector<array>&,
773 const std::vector<array>&)>
774 vjp_fun_;
775 std::function<std::vector<array>(
776 const std::vector<array>&,
777 const std::vector<array>&,
778 const std::vector<int>&)>
779 jvp_fun_;
780 std::function<std::pair<std::vector<array>, std::vector<int>>(
781 const std::vector<array>&,
782 const std::vector<int>&)>
783 vmap_fun_;
784};
785
786class Depends : public Primitive {
787 public:
788 explicit Depends(Stream stream) : Primitive(stream) {}
789
790 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
791 override;
792 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
793 override;
794
795 std::vector<array> vjp(
796 const std::vector<array>& primals,
797 const std::vector<array>& cotan,
798 const std::vector<int>& argnums,
799 const std::vector<array>& outputs) override;
800
802
803 private:
804 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
805};
806
807class Divide : public UnaryPrimitive {
808 public:
809 explicit Divide(Stream stream) : UnaryPrimitive(stream) {}
810
811 void eval_cpu(const std::vector<array>& inputs, array& out) override;
812 void eval_gpu(const std::vector<array>& inputs, array& out) override;
813
819
820 private:
821 void eval(const std::vector<array>& inputs, array& out);
822};
823
824class DivMod : public Primitive {
825 public:
826 explicit DivMod(Stream stream) : Primitive(stream) {}
827
828 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
829 override;
830 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
831 override;
832
837 std::vector<std::vector<int>> output_shapes(
838 const std::vector<array>& inputs) override {
839 return std::vector{inputs[0].shape(), inputs[0].shape()};
840 }
841
842 private:
843 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
844};
845
846class Select : public UnaryPrimitive {
847 public:
848 explicit Select(Stream stream) : UnaryPrimitive(stream) {}
849
850 void eval_cpu(const std::vector<array>& inputs, array& out) override;
851 void eval_gpu(const std::vector<array>& inputs, array& out) override;
852
858
859 private:
860 void eval(const std::vector<array>& inputs, array& out);
861};
862
863class Remainder : public UnaryPrimitive {
864 public:
865 explicit Remainder(Stream stream) : UnaryPrimitive(stream) {}
866
867 void eval_cpu(const std::vector<array>& inputs, array& out) override;
868 void eval_gpu(const std::vector<array>& inputs, array& out) override;
869
875
876 private:
877 void eval(const std::vector<array>& inputs, array& out);
878};
879
880class Equal : public UnaryPrimitive {
881 public:
882 explicit Equal(Stream stream, bool equal_nan = false)
883 : UnaryPrimitive(stream), equal_nan_(equal_nan) {}
884
885 void eval_cpu(const std::vector<array>& inputs, array& out) override;
886 void eval_gpu(const std::vector<array>& inputs, array& out) override;
887
892
893 void print(std::ostream& os) override {
894 if (equal_nan_) {
895 os << "NaNEqual";
896 } else {
897 os << "Equal";
898 }
899 }
900
901 private:
902 void eval(const std::vector<array>& inputs, array& out);
903 bool equal_nan_;
904};
905
906class Erf : public UnaryPrimitive {
907 public:
908 explicit Erf(Stream stream) : UnaryPrimitive(stream) {}
909
910 void eval_cpu(const std::vector<array>& inputs, array& out) override;
911 void eval_gpu(const std::vector<array>& inputs, array& out) override;
912
918
919 private:
920 void eval(const std::vector<array>& inputs, array& out);
921};
922
923class ErfInv : public UnaryPrimitive {
924 public:
925 explicit ErfInv(Stream stream) : UnaryPrimitive(stream) {}
926
927 void eval_cpu(const std::vector<array>& inputs, array& out) override;
928 void eval_gpu(const std::vector<array>& inputs, array& out) override;
929
935
936 private:
937 void eval(const std::vector<array>& inputs, array& out);
938};
939
940class Exp : public UnaryPrimitive {
941 public:
942 explicit Exp(Stream stream) : UnaryPrimitive(stream) {}
943
944 void eval_cpu(const std::vector<array>& inputs, array& out) override;
945 void eval_gpu(const std::vector<array>& inputs, array& out) override;
946
952
953 private:
954 void eval(const std::vector<array>& inputs, array& out);
955};
956
957class Expm1 : public UnaryPrimitive {
958 public:
959 explicit Expm1(Stream stream) : UnaryPrimitive(stream) {}
960
961 void eval_cpu(const std::vector<array>& inputs, array& out) override;
962 void eval_gpu(const std::vector<array>& inputs, array& out) override;
963
968
969 private:
970 void eval(const std::vector<array>& inputs, array& out);
971};
972
973class FFT : public UnaryPrimitive {
974 public:
975 explicit FFT(
976 Stream stream,
977 const std::vector<size_t>& axes,
978 bool inverse,
979 bool real)
980 : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {}
981
982 void eval_cpu(const std::vector<array>& inputs, array& out) override;
983 void eval_gpu(const std::vector<array>& inputs, array& out) override;
984
988
989 bool is_equivalent(const Primitive& other) const override;
990
991 private:
992 std::vector<size_t> axes_;
993 bool inverse_;
994 bool real_;
995
996 void eval(const std::vector<array>& inputs, array& out);
997};
998
999class Floor : public UnaryPrimitive {
1000 public:
1001 explicit Floor(Stream stream) : UnaryPrimitive(stream) {}
1002
1003 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1004 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1005
1011
1012 private:
1013 void eval(const std::vector<array>& inputs, array& out);
1014};
1015
1016class Full : public UnaryPrimitive {
1017 public:
1018 explicit Full(Stream stream) : UnaryPrimitive(stream) {}
1019
1020 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1021 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1022
1027
1028 private:
1029 void eval(const std::vector<array>& inputs, array& out);
1030};
1031
1032class Gather : public UnaryPrimitive {
1033 public:
1034 explicit Gather(
1035 Stream stream,
1036 const std::vector<int>& axes,
1037 const std::vector<int>& slice_sizes)
1038 : UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes) {}
1039
1040 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1041 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1042
1046 bool is_equivalent(const Primitive& other) const override;
1047
1048 private:
1049 void eval(const std::vector<array>& inputs, array& out);
1050 std::vector<int> axes_;
1051 std::vector<int> slice_sizes_;
1052};
1053
1054class Greater : public UnaryPrimitive {
1055 public:
1056 explicit Greater(Stream stream) : UnaryPrimitive(stream) {}
1057
1058 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1059 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1060
1066
1067 private:
1068 void eval(const std::vector<array>& inputs, array& out);
1069};
1070
1072 public:
1073 explicit GreaterEqual(Stream stream) : UnaryPrimitive(stream) {}
1074
1075 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1076 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1077
1083
1084 private:
1085 void eval(const std::vector<array>& inputs, array& out);
1086};
1087
1088class Hadamard : public UnaryPrimitive {
1089 public:
1090 explicit Hadamard(Stream stream, float scale)
1091 : UnaryPrimitive(stream), scale_(scale) {}
1092
1093 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1094 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1095
1100
1101 bool is_equivalent(const Primitive& other) const override;
1102
1103 private:
1104 float scale_;
1105
1106 void eval(const std::vector<array>& inputs, array& out);
1107};
1108
1109class Imag : public UnaryPrimitive {
1110 public:
1111 explicit Imag(Stream stream) : UnaryPrimitive(stream) {}
1112
1113 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1114 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1115
1121};
1122
1123class Less : public UnaryPrimitive {
1124 public:
1125 explicit Less(Stream stream) : UnaryPrimitive(stream) {}
1126
1127 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1128 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1129
1135
1136 private:
1137 void eval(const std::vector<array>& inputs, array& out);
1138};
1139
1141 public:
1142 explicit LessEqual(Stream stream) : UnaryPrimitive(stream) {}
1143
1144 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1145 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1146
1152
1153 private:
1154 void eval(const std::vector<array>& inputs, array& out);
1155};
1156
1157class Load : public UnaryPrimitive {
1158 public:
1159 explicit Load(
1160 Stream stream,
1161 std::shared_ptr<io::Reader> reader,
1162 size_t offset,
1163 bool swap_endianness = false)
1164 : UnaryPrimitive(stream),
1165 reader_(std::move(reader)),
1166 offset_(offset),
1167 swap_endianness_(swap_endianness) {
1168 if (stream.device == Device::gpu) {
1169 io_stream();
1170 }
1171 }
1172
1173 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1174 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1175
1177
1178 private:
1179 Stream& io_stream() {
1180 static Stream io_stream = new_stream(Device::cpu);
1181 return io_stream;
1182 };
1183 void eval(const std::vector<array>& inputs, array& out);
1184 std::shared_ptr<io::Reader> reader_;
1185 size_t offset_;
1186 bool swap_endianness_;
1187};
1188
1189class Log : public UnaryPrimitive {
1190 public:
1191 enum Base { two, ten, e };
1192
1193 explicit Log(Stream stream, Base base)
1194 : UnaryPrimitive(stream), base_(base) {}
1195
1196 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1197 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1198
1203
1204 void print(std::ostream& os) override {
1205 switch (base_) {
1206 case e:
1207 os << "Log";
1208 break;
1209 case two:
1210 os << "Log2";
1211 break;
1212 case ten:
1213 os << "Log10";
1214 break;
1215 }
1216 }
1217
1218 private:
1219 Base base_;
1220 void eval(const std::vector<array>& inputs, array& out);
1221};
1222
1223class Log1p : public UnaryPrimitive {
1224 public:
1225 explicit Log1p(Stream stream) : UnaryPrimitive(stream) {}
1226
1227 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1228 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1229
1234
1235 private:
1236 void eval(const std::vector<array>& inputs, array& out);
1237};
1238
1240 public:
1241 explicit LogicalNot(Stream stream) : UnaryPrimitive(stream) {}
1242
1243 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1244 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1245
1251
1252 private:
1253 void eval(const std::vector<array>& inputs, array& out);
1254};
1255
1257 public:
1258 explicit LogicalAnd(Stream stream) : UnaryPrimitive(stream) {}
1259
1260 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1261 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1262
1268
1269 private:
1270 void eval(const std::vector<array>& inputs, array& out);
1271};
1272
1274 public:
1275 explicit LogicalOr(Stream stream) : UnaryPrimitive(stream) {}
1276
1277 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1278 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1279
1285
1286 private:
1287 void eval(const std::vector<array>& inputs, array& out);
1288};
1289
1291 public:
1292 explicit LogAddExp(Stream stream) : UnaryPrimitive(stream) {}
1293
1294 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1295 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1296
1302
1303 private:
1304 void eval(const std::vector<array>& inputs, array& out);
1305};
1306
1307class Matmul : public UnaryPrimitive {
1308 public:
1309 explicit Matmul(Stream stream) : UnaryPrimitive(stream) {}
1310
1311 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1312 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1313
1314 std::vector<array> vjp(
1315 const std::vector<array>& primals,
1316 const std::vector<array>& cotangents,
1317 const std::vector<int>& argnums,
1318 const std::vector<array>& outputs) override;
1319
1323};
1324
1325class Maximum : public UnaryPrimitive {
1326 public:
1327 explicit Maximum(Stream stream) : UnaryPrimitive(stream) {}
1328
1329 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1330 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1331
1337
1338 private:
1339 void eval(const std::vector<array>& inputs, array& out);
1340};
1341
1342class Minimum : public UnaryPrimitive {
1343 public:
1344 explicit Minimum(Stream stream) : UnaryPrimitive(stream) {}
1345
1346 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1347 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1348
1354
1355 private:
1356 void eval(const std::vector<array>& inputs, array& out);
1357};
1358
1359class Multiply : public UnaryPrimitive {
1360 public:
1361 explicit Multiply(Stream stream) : UnaryPrimitive(stream) {}
1362
1363 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1364 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1365
1371
1372 private:
1373 void eval(const std::vector<array>& inputs, array& out);
1374};
1375
1376class Negative : public UnaryPrimitive {
1377 public:
1378 explicit Negative(Stream stream) : UnaryPrimitive(stream) {}
1379
1380 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1381 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1382
1388
1389 private:
1390 void eval(const std::vector<array>& inputs, array& out);
1391};
1392
1393class NotEqual : public UnaryPrimitive {
1394 public:
1395 explicit NotEqual(Stream stream) : UnaryPrimitive(stream) {}
1396
1397 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1398 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1399
1405
1406 private:
1407 void eval(const std::vector<array>& inputs, array& out);
1408};
1409
1411 public:
1413 Stream stream,
1414 std::vector<int> axes,
1415 bool inverted,
1416 Dtype dtype)
1417 : UnaryPrimitive(stream),
1418 axes_(std::move(axes)),
1419 inverted_(inverted),
1420 dtype_(dtype) {}
1421
1422 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1423 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1424
1427 bool is_equivalent(const Primitive& other) const override;
1428 std::vector<std::vector<int>> output_shapes(
1429 const std::vector<array>& inputs) override {
1430 return {{}};
1431 }
1432
1433 private:
1434 std::vector<int> axes_;
1435 bool inverted_;
1436 Dtype dtype_;
1437
1438 void eval(const std::vector<array>& inputs, array& out);
1439};
1440
1441class Pad : public UnaryPrimitive {
1442 public:
1443 explicit Pad(
1444 Stream stream,
1445 const std::vector<int>& axes,
1446 const std::vector<int>& low_pad_size,
1447 const std::vector<int>& high_pad_size)
1448 : UnaryPrimitive(stream),
1449 axes_(axes),
1450 low_pad_size_(low_pad_size),
1451 high_pad_size_(high_pad_size) {}
1452
1453 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1454 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1455
1459 bool is_equivalent(const Primitive& other) const override;
1460
1461 private:
1462 std::vector<int> axes_;
1463 std::vector<int> low_pad_size_;
1464 std::vector<int> high_pad_size_;
1465
1466 void eval(const std::vector<array>& inputs, array& out);
1467};
1468
1470 public:
1471 explicit Partition(Stream stream, int kth, int axis)
1472 : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
1473
1474 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1475 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1476
1481 bool is_equivalent(const Primitive& other) const override;
1482
1483 private:
1484 int kth_;
1485 int axis_;
1486
1487 void eval(const std::vector<array>& inputs, array& out);
1488};
1489
1490class Power : public UnaryPrimitive {
1491 public:
1492 explicit Power(Stream stream) : UnaryPrimitive(stream) {}
1493
1494 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1495 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1496
1502
1503 private:
1504 void eval(const std::vector<array>& inputs, array& out);
1505};
1506
1508 public:
1510 Stream stream,
1511 int group_size,
1512 int bits,
1513 bool transpose)
1514 : UnaryPrimitive(stream),
1515 group_size_(group_size),
1516 bits_(bits),
1517 transpose_(transpose) {}
1518
1519 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1520 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1521
1525 bool is_equivalent(const Primitive& other) const override;
1526
1527 private:
1528 int group_size_;
1529 int bits_;
1530 bool transpose_;
1531
1532 void eval(const std::vector<array>& inputs, array& out);
1533};
1534
1536 public:
1537 explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
1538 : UnaryPrimitive(stream),
1539 group_size_(group_size),
1540 bits_(bits),
1541 transpose_(transpose) {}
1542
1543 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1544 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1545
1549 bool is_equivalent(const Primitive& other) const override;
1550
1551 private:
1552 int group_size_;
1553 int bits_;
1554 bool transpose_;
1555
1556 void eval(const std::vector<array>& inputs, array& out);
1557};
1558
1560 public:
1561 explicit RandomBits(Stream stream, const std::vector<int>& shape, int width)
1562 : UnaryPrimitive(stream), shape_(shape), width_(width) {}
1563
1564 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1565 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1566
1569 bool is_equivalent(const Primitive& other) const override;
1570
1571 private:
1572 std::vector<int> shape_;
1573 int width_;
1574
1575 void eval(const std::vector<array>& inputs, array& out);
1576};
1577
1578class Real : public UnaryPrimitive {
1579 public:
1580 explicit Real(Stream stream) : UnaryPrimitive(stream) {}
1581
1582 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1583 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1584
1590};
1591
1592class Reshape : public UnaryPrimitive {
1593 public:
1594 explicit Reshape(Stream stream, const std::vector<int>& shape)
1595 : UnaryPrimitive(stream), shape_(shape) {}
1596
1597 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1598 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1599
1603 bool is_equivalent(const Primitive& other) const override;
1604
1605 private:
1606 std::vector<int> shape_;
1607
1608 void eval(const std::vector<array>& inputs, array& out);
1609
1610 std::pair<bool, std::vector<size_t>> prepare_reshape(
1611 const array& in,
1612 const array& out);
1613 void shared_buffer_reshape(
1614 const array& in,
1615 const std::vector<size_t>& out_strides,
1616 array& out);
1617};
1618
1619class Reduce : public UnaryPrimitive {
1620 public:
1621 enum ReduceType { And, Or, Sum, Prod, Min, Max };
1622
1623 explicit Reduce(
1624 Stream stream,
1625 ReduceType reduce_type,
1626 const std::vector<int>& axes)
1627 : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1628
1629 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1630 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1631
1633
1634 std::vector<array> vjp(
1635 const std::vector<array>& primals,
1636 const std::vector<array>& cotangents,
1637 const std::vector<int>& argnums,
1638 const std::vector<array>& outputs) override;
1639
1640 std::vector<std::vector<int>> output_shapes(
1641 const std::vector<array>& inputs) override;
1642
1643 void print(std::ostream& os) override {
1644 switch (reduce_type_) {
1645 case And:
1646 os << "And";
1647 break;
1648 case Or:
1649 os << "Or";
1650 break;
1651 case Sum:
1652 os << "Sum";
1653 break;
1654 case Prod:
1655 os << "Prod";
1656 break;
1657 case Min:
1658 os << "Min";
1659 break;
1660 case Max:
1661 os << "Max";
1662 break;
1663 }
1664 }
1665 bool is_equivalent(const Primitive& other) const override;
1666
1667 private:
1668 ReduceType reduce_type_;
1669 std::vector<int> axes_;
1670
1671 void eval(const std::vector<array>& inputs, array& out);
1672};
1673
1674class Round : public UnaryPrimitive {
1675 public:
1676 explicit Round(Stream stream) : UnaryPrimitive(stream) {}
1677
1678 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1679 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1680
1686
1687 private:
1688 void eval(const std::vector<array>& inputs, array& out);
1689};
1690
1691class Scan : public UnaryPrimitive {
1692 public:
1694
1695 explicit Scan(
1696 Stream stream,
1697 ReduceType reduce_type,
1698 int axis,
1699 bool reverse,
1700 bool inclusive)
1701 : UnaryPrimitive(stream),
1702 reduce_type_(reduce_type),
1703 axis_(axis),
1704 reverse_(reverse),
1705 inclusive_(inclusive) {}
1706
1707 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1708 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1709
1712
1713 void print(std::ostream& os) override {
1714 os << "Cum";
1715 switch (reduce_type_) {
1716 case Sum:
1717 os << "Sum";
1718 break;
1719 case Prod:
1720 os << "Prod";
1721 break;
1722 case Min:
1723 os << "Min";
1724 break;
1725 case Max:
1726 os << "Max";
1727 break;
1728 }
1729 }
1730 bool is_equivalent(const Primitive& other) const override;
1731
1732 private:
1733 ReduceType reduce_type_;
1734 int axis_;
1735 bool reverse_;
1736 bool inclusive_;
1737
1738 void eval(const std::vector<array>& inputs, array& out);
1739};
1740
1741class Scatter : public UnaryPrimitive {
1742 public:
1744
1745 explicit Scatter(
1746 Stream stream,
1747 ReduceType reduce_type,
1748 const std::vector<int>& axes)
1749 : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1750
1751 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1752 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1753
1756
1757 void print(std::ostream& os) override {
1758 os << "Scatter";
1759 switch (reduce_type_) {
1760 case Sum:
1761 os << " Sum";
1762 break;
1763 case Prod:
1764 os << " Prod";
1765 break;
1766 case Min:
1767 os << " Min";
1768 break;
1769 case Max:
1770 os << " Max";
1771 break;
1772 case None:
1773 break;
1774 }
1775 }
1776 bool is_equivalent(const Primitive& other) const override;
1777
1778 private:
1779 void eval(const std::vector<array>& inputs, array& out);
1780 ReduceType reduce_type_;
1781 std::vector<int> axes_;
1782};
1783
1784class Sigmoid : public UnaryPrimitive {
1785 public:
1786 explicit Sigmoid(Stream stream) : UnaryPrimitive(stream) {}
1787
1788 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1789 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1790
1796
1797 private:
1798 void eval(const std::vector<array>& inputs, array& out);
1799};
1800
1801class Sign : public UnaryPrimitive {
1802 public:
1803 explicit Sign(Stream stream) : UnaryPrimitive(stream) {}
1804
1805 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1806 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1807
1813
1814 private:
1815 void eval(const std::vector<array>& inputs, array& out);
1816};
1817
1818class Sin : public UnaryPrimitive {
1819 public:
1820 explicit Sin(Stream stream) : UnaryPrimitive(stream) {}
1821
1822 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1823 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1824
1830
1831 private:
1832 void eval(const std::vector<array>& inputs, array& out);
1833};
1834
1835class Sinh : public UnaryPrimitive {
1836 public:
1837 explicit Sinh(Stream stream) : UnaryPrimitive(stream) {}
1838
1839 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1840 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1841
1847
1848 private:
1849 void eval(const std::vector<array>& inputs, array& out);
1850};
1851
1852class Slice : public UnaryPrimitive {
1853 public:
1854 explicit Slice(
1855 Stream stream,
1856 const std::vector<int>& start_indices,
1857 const std::vector<int>& end_indices,
1858 const std::vector<int>& strides)
1859 : UnaryPrimitive(stream),
1860 start_indices_(start_indices),
1861 end_indices_(end_indices),
1862 strides_(strides) {}
1863
1864 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1865 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1866
1870 bool is_equivalent(const Primitive& other) const override;
1871
1872 private:
1873 std::vector<int> start_indices_;
1874 std::vector<int> end_indices_;
1875 std::vector<int> strides_;
1876
1877 void eval(const std::vector<array>& inputs, array& out);
1878};
1879
1881 public:
1882 explicit SliceUpdate(
1883 Stream stream,
1884 const std::vector<int>& start_indices,
1885 const std::vector<int>& end_indices,
1886 const std::vector<int>& strides)
1887 : UnaryPrimitive(stream),
1888 start_indices_(start_indices),
1889 end_indices_(end_indices),
1890 strides_(strides) {}
1891
1892 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1893 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1894
1898 bool is_equivalent(const Primitive& other) const override;
1899
1900 private:
1901 std::vector<int> start_indices_;
1902 std::vector<int> end_indices_;
1903 std::vector<int> strides_;
1904
1905 void eval(const std::vector<array>& inputs, array& out);
1906
1907 std::tuple<int64_t, std::vector<int64_t>> prepare_slice(const array& in);
1908};
1909
1910class Softmax : public UnaryPrimitive {
1911 public:
1912 explicit Softmax(Stream stream, bool precise)
1913 : UnaryPrimitive(stream), precise_(precise) {}
1914
1915 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1916 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1917
1922
1923 bool is_equivalent(const Primitive& other) const override;
1924
1925 private:
1926 void eval(const std::vector<array>& inputs, array& out);
1927 bool precise_;
1928};
1929
1930class Sort : public UnaryPrimitive {
1931 public:
1932 explicit Sort(Stream stream, int axis)
1933 : UnaryPrimitive(stream), axis_(axis) {}
1934
1935 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1936 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1937
1942 bool is_equivalent(const Primitive& other) const override;
1943
1944 private:
1945 int axis_;
1946
1947 void eval(const std::vector<array>& inputs, array& out);
1948};
1949
1950class Split : public Primitive {
1951 public:
1952 explicit Split(Stream stream, const std::vector<int>& indices, int axis)
1953 : Primitive(stream), indices_(indices), axis_(axis) {}
1954
1955 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
1956 override;
1957 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
1958 override;
1959
1963 bool is_equivalent(const Primitive& other) const override;
1964
1965 private:
1966 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
1967
1968 std::vector<int> indices_;
1969 int axis_;
1970};
1971
1972class Square : public UnaryPrimitive {
1973 public:
1974 explicit Square(Stream stream) : UnaryPrimitive(stream) {}
1975
1976 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1977 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1978
1984
1985 private:
1986 void eval(const std::vector<array>& inputs, array& out);
1987};
1988
1989class Sqrt : public UnaryPrimitive {
1990 public:
1991 explicit Sqrt(Stream stream, bool recip = false)
1992 : UnaryPrimitive(stream), recip_(recip) {}
1993
1994 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1995 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1996
2000 bool is_equivalent(const Primitive& other) const override;
2001
2002 void print(std::ostream& os) override {
2003 if (recip_) {
2004 os << "Rsqrt";
2005 } else {
2006 os << "Sqrt";
2007 }
2008 }
2009
2010 private:
2011 void eval(const std::vector<array>& inputs, array& out);
2012 bool recip_;
2013};
2014
2016 public:
2017 explicit StopGradient(Stream stream) : UnaryPrimitive(stream) {}
2018
2019 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2020 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2021
2026
2027 private:
2028 void eval(const std::vector<array>& inputs, array& out);
2029};
2030
2031class Subtract : public UnaryPrimitive {
2032 public:
2033 explicit Subtract(Stream stream) : UnaryPrimitive(stream) {}
2034
2035 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2036 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2037
2043
2044 private:
2045 void eval(const std::vector<array>& inputs, array& out);
2046};
2047
2048class Tan : public UnaryPrimitive {
2049 public:
2050 explicit Tan(Stream stream) : UnaryPrimitive(stream) {}
2051
2052 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2053 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2054
2060
2061 private:
2062 void eval(const std::vector<array>& inputs, array& out);
2063};
2064
2065class Tanh : public UnaryPrimitive {
2066 public:
2067 explicit Tanh(Stream stream) : UnaryPrimitive(stream) {}
2068
2069 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2070 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2071
2077
2078 private:
2079 void eval(const std::vector<array>& inputs, array& out);
2080};
2081
2082class Uniform : public UnaryPrimitive {
2083 public:
2084 explicit Uniform(Stream stream) : UnaryPrimitive(stream) {}
2085
2086 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2087 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2088
2092
2093 private:
2094 void eval(const std::vector<array>& inputs, array& out);
2095};
2096
2097class View : public UnaryPrimitive {
2098 public:
2099 explicit View(Stream stream, Dtype dtype)
2100 : UnaryPrimitive(stream), dtype_(dtype) {}
2101
2102 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2103 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2104
2106 void print(std::ostream& os) override;
2107 bool is_equivalent(const Primitive& other) const override;
2108
2109 private:
2110 Dtype dtype_;
2111};
2112
2114 public:
2115 explicit Transpose(Stream stream, const std::vector<int>& axes)
2116 : UnaryPrimitive(stream), axes_(axes) {}
2117
2118 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2119 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2120
2124 bool is_equivalent(const Primitive& other) const override;
2125
2126 private:
2127 std::vector<int> axes_;
2128
2129 void eval(const std::vector<array>& inputs, array& out);
2130};
2131
2132/* QR Factorization primitive. */
2133class QRF : public Primitive {
2134 public:
2135 explicit QRF(Stream stream) : Primitive(stream) {}
2136
2137 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2138 override;
2139 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2140 override;
2141
2143
2144 private:
2145 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
2146};
2147
2148/* SVD primitive. */
2149class SVD : public Primitive {
2150 public:
2151 explicit SVD(Stream stream) : Primitive(stream) {}
2152
2153 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2154 override;
2155 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2156 override;
2157
2160
2161 private:
2162 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
2163};
2164
2165/* Matrix inversion primitive. */
2166class Inverse : public UnaryPrimitive {
2167 public:
2168 explicit Inverse(Stream stream, bool tri, bool upper)
2169 : UnaryPrimitive(stream), tri_(tri), upper_(upper) {}
2170
2171 void eval_cpu(const std::vector<array>& inputs, array& output) override;
2172 void eval_gpu(const std::vector<array>& inputs, array& output) override;
2173
2176
2177 private:
2178 void eval(const std::vector<array>& inputs, array& output);
2179 bool tri_;
2180 bool upper_;
2181};
2182
2183class Cholesky : public UnaryPrimitive {
2184 public:
2185 explicit Cholesky(Stream stream, bool upper)
2186 : UnaryPrimitive(stream), upper_(upper) {}
2187
2188 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2189 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2190
2193
2194 private:
2195 void eval(const std::vector<array>& inputs, array& output);
2196 bool upper_;
2197};
2198
2199class Eigh : public Primitive {
2200 public:
2201 explicit Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
2202 : Primitive(stream),
2203 uplo_(std::move(uplo)),
2204 compute_eigenvectors_(compute_eigenvectors) {}
2205
2206 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2207 override;
2208 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2209 override;
2210
2213
2214 std::vector<std::vector<int>> output_shapes(
2215 const std::vector<array>& inputs) override {
2216 auto shape = inputs[0].shape();
2217 shape.pop_back(); // Remove last dimension for eigenvalues
2218 if (compute_eigenvectors_) {
2219 return {shape, inputs[0].shape()}; // Eigenvalues and eigenvectors
2220 } else {
2221 return {shape}; // Only eigenvalues
2222 }
2223 }
2224
2225 bool is_equivalent(const Primitive& other) const override {
2226 if (auto* p = dynamic_cast<const Eigh*>(&other)) {
2227 return uplo_ == p->uplo_ &&
2228 compute_eigenvectors_ == p->compute_eigenvectors_;
2229 }
2230 return false;
2231 }
2232
2233 private:
2234 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
2235 std::string uplo_;
2236 bool compute_eigenvectors_;
2237};
2238
2239} // namespace mlx::core
Definition primitives.h:155
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Abs(Stream stream)
Definition primitives.h:157
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:164
std::vector< std::vector< int > > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:166
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:165
Definition primitives.h:172
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Add(Stream stream)
Definition primitives.h:174
Definition primitives.h:189
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
AddMM(Stream stream, float alpha, float beta)
Definition primitives.h:191
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition primitives.h:213
Arange(Stream stream, double start, double stop, double step)
Definition primitives.h:215
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:232
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCos(Stream stream)
Definition primitives.h:234
Definition primitives.h:249
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCosh(Stream stream)
Definition primitives.h:251
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:266
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcSin(Stream stream)
Definition primitives.h:268
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:283
ArcSinh(Stream stream)
Definition primitives.h:285
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:317
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTan2(Stream stream)
Definition primitives.h:319
Definition primitives.h:300
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcTan(Stream stream)
Definition primitives.h:302
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:334
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTanh(Stream stream)
Definition primitives.h:336
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:351
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArgPartition(Stream stream, int kth, int axis)
Definition primitives.h:353
Definition primitives.h:372
ReduceType
Definition primitives.h:374
@ ArgMin
Definition primitives.h:375
@ ArgMax
Definition primitives.h:376
ArgReduce(Stream stream, ReduceType reduce_type, int axis)
Definition primitives.h:379
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:399
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArgSort(Stream stream, int axis)
Definition primitives.h:401
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:438
AsStrided(Stream stream, std::vector< int > shape, std::vector< size_t > strides, size_t offset)
Definition primitives.h:440
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:418
void eval_gpu(const std::vector< array > &inputs, array &out) override
AsType(Stream stream, Dtype dtype)
Definition primitives.h:420
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:465
BitwiseBinary(Stream stream, Op op)
Definition primitives.h:469
void eval_cpu(const std::vector< array > &inputs, array &out) override
Op
Definition primitives.h:467
@ And
Definition primitives.h:467
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:485
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
BlockMaskedMM(Stream stream, int block_size)
Definition primitives.h:487
Definition primitives.h:528
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Broadcast(Stream stream, const std::vector< int > &shape)
Definition primitives.h:530
Definition primitives.h:547
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Ceil(Stream stream)
Definition primitives.h:549
Definition primitives.h:2183
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cholesky(Stream stream, bool upper)
Definition primitives.h:2185
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:564
Compiled(Stream stream, std::vector< array > inputs, std::vector< array > outputs, std::vector< array > tape, std::unordered_set< uintptr_t > constant_ids)
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:607
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Concatenate(Stream stream, int axis)
Definition primitives.h:609
Definition primitives.h:626
Conjugate(Stream stream)
Definition primitives.h:628
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:642
void eval_gpu(const std::vector< array > &inputs, array &out) override
Convolution(Stream stream, const std::vector< int > &kernel_strides, const std::vector< int > &padding, const std::vector< int > &kernel_dilation, const std::vector< int > &input_dilation, const int groups=1, const bool flip=false)
Definition primitives.h:644
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition primitives.h:683
void eval_gpu(const std::vector< array > &inputs, array &out) override
Copy(Stream stream)
Definition primitives.h:685
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:700
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cos(Stream stream)
Definition primitives.h:702
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:717
void eval_gpu(const std::vector< array > &inputs, array &out) override
Cosh(Stream stream)
Definition primitives.h:719
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:734
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
CustomTransforms(Stream stream, int num_outputs, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> vjp, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< int > &)> jvp, std::function< std::pair< std::vector< array >, std::vector< int > >(const std::vector< array > &, const std::vector< int > &)> vmap)
Definition primitives.h:736
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:786
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Depends(Stream stream)
Definition primitives.h:788
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:824
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
DivMod(Stream stream)
Definition primitives.h:826
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:807
Divide(Stream stream)
Definition primitives.h:809
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2199
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:2225
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
Definition primitives.h:2201
Definition primitives.h:880
Equal(Stream stream, bool equal_nan=false)
Definition primitives.h:882
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:906
Erf(Stream stream)
Definition primitives.h:908
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:923
void eval_gpu(const std::vector< array > &inputs, array &out) override
ErfInv(Stream stream)
Definition primitives.h:925
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:940
Exp(Stream stream)
Definition primitives.h:942
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:957
Expm1(Stream stream)
Definition primitives.h:959
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:973
FFT(Stream stream, const std::vector< size_t > &axes, bool inverse, bool real)
Definition primitives.h:975
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:999
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Floor(Stream stream)
Definition primitives.h:1001
Definition primitives.h:1016
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Full(Stream stream)
Definition primitives.h:1018
Definition primitives.h:1032
Gather(Stream stream, const std::vector< int > &axes, const std::vector< int > &slice_sizes)
Definition primitives.h:1034
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:508
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
GatherMM(Stream stream)
Definition primitives.h:510
Definition primitives.h:1535
GatherQMM(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1537
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1071
void eval_cpu(const std::vector< array > &inputs, array &out) override
GreaterEqual(Stream stream)
Definition primitives.h:1073
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1054
Greater(Stream stream)
Definition primitives.h:1056
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1088
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Hadamard(Stream stream, float scale)
Definition primitives.h:1090
Definition primitives.h:1109
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Imag(Stream stream)
Definition primitives.h:1111
Definition primitives.h:2166
void eval_gpu(const std::vector< array > &inputs, array &output) override
Inverse(Stream stream, bool tri, bool upper)
Definition primitives.h:2168
void eval_cpu(const std::vector< array > &inputs, array &output) override
Definition primitives.h:1140
LessEqual(Stream stream)
Definition primitives.h:1142
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1123
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Less(Stream stream)
Definition primitives.h:1125
Definition primitives.h:1157
void eval_gpu(const std::vector< array > &inputs, array &out) override
Load(Stream stream, std::shared_ptr< io::Reader > reader, size_t offset, bool swap_endianness=false)
Definition primitives.h:1159
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1223
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Log1p(Stream stream)
Definition primitives.h:1225
Definition primitives.h:1290
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogAddExp(Stream stream)
Definition primitives.h:1292
Definition primitives.h:1189
Base
Definition primitives.h:1191
Log(Stream stream, Base base)
Definition primitives.h:1193
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1256
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalAnd(Stream stream)
Definition primitives.h:1258
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1239
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalNot(Stream stream)
Definition primitives.h:1241
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1273
void eval_cpu(const std::vector< array > &inputs, array &out) override
LogicalOr(Stream stream)
Definition primitives.h:1275
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1307
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
Matmul(Stream stream)
Definition primitives.h:1309
Definition primitives.h:1325
Maximum(Stream stream)
Definition primitives.h:1327
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1342
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Minimum(Stream stream)
Definition primitives.h:1344
Definition primitives.h:1359
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Multiply(Stream stream)
Definition primitives.h:1361
Definition primitives.h:1376
void eval_gpu(const std::vector< array > &inputs, array &out) override
Negative(Stream stream)
Definition primitives.h:1378
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1393
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
NotEqual(Stream stream)
Definition primitives.h:1395
Definition primitives.h:1410
void eval_gpu(const std::vector< array > &inputs, array &out) override
NumberOfElements(Stream stream, std::vector< int > axes, bool inverted, Dtype dtype)
Definition primitives.h:1412
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1441
void eval_cpu(const std::vector< array > &inputs, array &out) override
Pad(Stream stream, const std::vector< int > &axes, const std::vector< int > &low_pad_size, const std::vector< int > &high_pad_size)
Definition primitives.h:1443
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1469
void eval_cpu(const std::vector< array > &inputs, array &out) override
Partition(Stream stream, int kth, int axis)
Definition primitives.h:1471
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1490
void eval_cpu(const std::vector< array > &inputs, array &out) override
Power(Stream stream)
Definition primitives.h:1492
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:48
virtual void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
The vector-Jacobian product.
virtual ~Primitive()=default
Primitive(const Primitive &other)=delete
Primitive(Primitive &&other)=delete
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
Primitive & operator=(Primitive &&other)=delete
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
Primitive & operator=(const Primitive &other)=delete
virtual std::vector< std::vector< int > > output_shapes(const std::vector< array > &inputs)
Get the output shapes of the primitive.
const Device & device()
The device the primitive will run on.
Definition primitives.h:53
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
The Jacobian-vector product.
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes)
The primitive must know how to vectorize itself across the given axes.
virtual void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
virtual void print(std::ostream &os)=0
Print the primitive.
Primitive(Stream stream)
Definition primitives.h:50
Definition primitives.h:2133
QRF(Stream stream)
Definition primitives.h:2135
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:1507
void eval_gpu(const std::vector< array > &inputs, array &out) override
QuantizedMatmul(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1509
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1559
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
RandomBits(Stream stream, const std::vector< int > &shape, int width)
Definition primitives.h:1561
Definition primitives.h:1578
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Real(Stream stream)
Definition primitives.h:1580
Definition primitives.h:1619
Reduce(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1623
ReduceType
Definition primitives.h:1621
@ And
Definition primitives.h:1621
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:863
Remainder(Stream stream)
Definition primitives.h:865
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1592
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Reshape(Stream stream, const std::vector< int > &shape)
Definition primitives.h:1594
Definition primitives.h:1674
Round(Stream stream)
Definition primitives.h:1676
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2149
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
SVD(Stream stream)
Definition primitives.h:2151
Definition primitives.h:1691
void eval_cpu(const std::vector< array > &inputs, array &out) override
ReduceType
Definition primitives.h:1693
@ Max
Definition primitives.h:1693
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Scan(Stream stream, ReduceType reduce_type, int axis, bool reverse, bool inclusive)
Definition primitives.h:1695
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1741
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
ReduceType
Definition primitives.h:1743
@ Max
Definition primitives.h:1743
void eval_cpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1757
void eval_gpu(const std::vector< array > &inputs, array &out) override
Scatter(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1745
Definition primitives.h:846
void eval_gpu(const std::vector< array > &inputs, array &out) override
Select(Stream stream)
Definition primitives.h:848
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1784
Sigmoid(Stream stream)
Definition primitives.h:1786
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1801
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Sign(Stream stream)
Definition primitives.h:1803
Definition primitives.h:1818
Sin(Stream stream)
Definition primitives.h:1820
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1835
Sinh(Stream stream)
Definition primitives.h:1837
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1852
void eval_cpu(const std::vector< array > &inputs, array &out) override
Slice(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
Definition primitives.h:1854
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1880
SliceUpdate(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
Definition primitives.h:1882
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1910
void eval_gpu(const std::vector< array > &inputs, array &out) override
Softmax(Stream stream, bool precise)
Definition primitives.h:1912
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1930
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sort(Stream stream, int axis)
Definition primitives.h:1932
Definition primitives.h:1950
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Split(Stream stream, const std::vector< int > &indices, int axis)
Definition primitives.h:1952
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:1989
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sqrt(Stream stream, bool recip=false)
Definition primitives.h:1991
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1972
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Square(Stream stream)
Definition primitives.h:1974
Definition primitives.h:2015
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
StopGradient(Stream stream)
Definition primitives.h:2017
Definition primitives.h:2031
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Subtract(Stream stream)
Definition primitives.h:2033
Definition primitives.h:2048
Tan(Stream stream)
Definition primitives.h:2050
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2065
void eval_gpu(const std::vector< array > &inputs, array &out) override
Tanh(Stream stream)
Definition primitives.h:2067
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2113
Transpose(Stream stream, const std::vector< int > &axes)
Definition primitives.h:2115
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:127
UnaryPrimitive & operator=(const UnaryPrimitive &other)=delete
UnaryPrimitive(Stream stream)
An abstract base class for a primitive with a single output.
Definition primitives.h:132
virtual void eval_gpu(const std::vector< array > &inputs, array &output)=0
UnaryPrimitive(UnaryPrimitive &&other)=delete
virtual void eval_cpu(const std::vector< array > &inputs, array &output)=0
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:142
UnaryPrimitive(const UnaryPrimitive &other)=delete
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:137
UnaryPrimitive & operator=(UnaryPrimitive &&other)=delete
virtual ~UnaryPrimitive()=default
Definition primitives.h:2082
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Uniform(Stream stream)
Definition primitives.h:2084
Definition primitives.h:2097
void eval_cpu(const std::vector< array > &inputs, array &out) override
View(Stream stream, Dtype dtype)
Definition primitives.h:2099
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition array.h:20
Op op
Definition binary.h:129
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
array tri(int n, int m, int k, Dtype type, StreamOrDevice s={})
array transpose(const array &a, std::vector< int > axes, StreamOrDevice s={})
Permutes the dimensions according to the given axes.
array real(const array &a, StreamOrDevice s={})
Definition allocator.h:7
std::pair< std::vector< array >, std::vector< array > > jvp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &tangents)
Computes the output and Jacobian-vector product (JVP) of a function.
std::pair< std::vector< array >, std::vector< array > > vjp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &cotangents)
Computes the output and vector-Jacobian product (VJP) of a function.
void eval(std::vector< array > outputs)
std::function< array(const array &)> vmap(const std::function< array(const array &)> &fun, int in_axis=0, int out_axis=0)
Automatically vectorize a unary function over the requested axes.
#define DEFINE_DEFAULT_IS_EQUIVALENT()
Definition primitives.h:34
#define DEFINE_PRINT(PRIMITIVE)
Definition primitives.h:29
#define DEFINE_INPUT_OUTPUT_SHAPE()
Definition primitives.h:39
#define DEFINE_GRADS()
Definition primitives.h:17
#define DEFINE_VMAP()
Definition primitives.h:12
Definition ops.h:37
Definition binary_ops.h:270
Definition ops.h:185
Definition ops.h:163
Definition ops.h:29
Definition ops.h:78
Definition ops.h:141
Definition binary_ops.h:277
Definition ops.h:119
Definition device.h:7
Definition dtype.h:13
Definition stream.h:9
Device device
Definition stream.h:11