MLX
Loading...
Searching...
No Matches
primitives.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <unordered_set>
6
7#include "mlx/array.h"
8#include "mlx/device.h"
9#include "mlx/io/load.h"
10#include "mlx/stream.h"
11
// Declares (does not define) a vmap() override inside a primitive's class
// body. The parameters mirror Primitive::vmap: the inputs and one int per
// entry of `axes`.
12#define DEFINE_VMAP() \
13 virtual std::pair<std::vector<array>, std::vector<int>> vmap( \
14 const std::vector<array>& inputs, const std::vector<int>& axes) \
15 override;
16
// Declares (does not define) jvp() and vjp() overrides inside a primitive's
// class body. Both signatures match the corresponding Primitive virtuals.
17#define DEFINE_GRADS() \
18 std::vector<array> jvp( \
19 const std::vector<array>& primals, \
20 const std::vector<array>& tangents, \
21 const std::vector<int>& argnums) override; \
22 \
23 std::vector<array> vjp( \
24 const std::vector<array>& primals, \
25 const std::vector<array>& cotangents, \
26 const std::vector<int>& argnums, \
27 const std::vector<array>& outputs) override;
28
// Defines an inline print() override that streams the macro argument as a
// string literal (#PRIMITIVE stringifies it), e.g. DEFINE_PRINT(Abs) writes
// "Abs" to `os`.
29#define DEFINE_PRINT(PRIMITIVE) \
30 void print(std::ostream& os) override { \
31 os << #PRIMITIVE; \
32 }
33
// Defines an inline is_equivalent() override that returns true
// unconditionally; the `other` argument is never inspected.
// NOTE(review): presumably intended for primitives with no state to compare
// -- confirm at each use site.
34#define DEFINE_DEFAULT_IS_EQUIVALENT() \
35 bool is_equivalent(const Primitive& other) const override { \
36 return true; \
37 }
38
// Defines an inline output_shapes() override that returns a one-element list
// containing the shape of the first input. It indexes inputs[0] without
// checking that `inputs` is non-empty.
39#define DEFINE_INPUT_OUTPUT_SHAPE() \
40 std::vector<Shape> output_shapes(const std::vector<array>& inputs) \
41 override { \
42 return {inputs[0].shape()}; \
43 }
44
45namespace mlx::core {
46
47// Abstract base class
// Abstract base class for every primitive. A primitive stores the stream it
// was created with and exposes backend evaluation (eval_cpu / eval_gpu),
// function transformations (jvp / vjp / vmap), printing, an equivalence test
// and output-shape inference. Instances can be neither copied nor moved.
48class Primitive {
49 public:
 // Binds the primitive to `stream`.
50 explicit Primitive(Stream stream) : stream_(stream) {}
51
 // Device of the stored stream. Const-qualified so it can be called through
 // a const Primitive& (for example inside is_equivalent).
53 const Device& device() const {
54 return stream().device;
55 }
56
 // Stream stored at construction. Const-qualified for the same reason as
 // device().
58 const Stream& stream() const {
59 return stream_;
60 }
61
 // Backend-specific evaluation. Both are pure virtual: a concrete primitive
 // must implement each one. `outputs` is passed by non-const reference.
69 virtual void eval_cpu(
70 const std::vector<array>& inputs,
71 std::vector<array>& outputs) = 0;
72 virtual void eval_gpu(
73 const std::vector<array>& inputs,
74 std::vector<array>& outputs) = 0;
75
 // Jacobian-vector product hook. `argnums` presumably selects which primals
 // the tangents correspond to -- confirm against the implementation.
79 virtual std::vector<array> jvp(
80 const std::vector<array>& primals,
81 const std::vector<array>& tangents,
82 const std::vector<int>& argnums);
83
 // Vector-Jacobian product hook; additionally receives the primitive's
 // outputs.
87 virtual std::vector<array> vjp(
88 const std::vector<array>& primals,
89 const std::vector<array>& cotangents,
90 const std::vector<int>& argnums,
91 const std::vector<array>& outputs);
92
 // Vectorization hook: takes the inputs plus an axis list and returns a pair
 // of (arrays, axes).
99 virtual std::pair<std::vector<array>, std::vector<int>> vmap(
100 const std::vector<array>& inputs,
101 const std::vector<int>& axes);
102
 // Writes a textual description of the primitive to `os` (pure virtual).
104 virtual void print(std::ostream& os) = 0;
105
 // Equivalence test against another primitive. The base implementation
 // returns false for every `other`.
107 virtual bool is_equivalent(const Primitive& other) const {
108 return false;
109 }
110
 // Output-shape inference from the inputs; defined out of line.
113 virtual std::vector<Shape> output_shapes(const std::vector<array>& inputs);
114
 // Virtual destructor so derived primitives can be destroyed through a base
 // pointer; copy and move operations are all deleted.
115 virtual ~Primitive() = default;
116 Primitive(const Primitive& other) = delete;
117 Primitive(Primitive&& other) = delete;
118 Primitive& operator=(const Primitive& other) = delete;
119 Primitive& operator=(Primitive&& other) = delete;
120
121 private:
122 // Every primitive stores the stream it should run in
123 Stream stream_;
124};
125
// Base class for primitives that produce a single output array. Subclasses
// implement the (inputs, array& output) overloads; the vector-based
// overrides from Primitive forward to them using outputs[0].
// NOTE(review): the embedded listing numbers are not contiguous here (131,
// 149 and 151 are absent), so the constructor and possibly other members are
// not shown.
126class UnaryPrimitive : public Primitive {
130 public:
132
 // Single-output evaluation entry points (pure virtual).
133 virtual void eval_cpu(const std::vector<array>& inputs, array& output) = 0;
134 virtual void eval_gpu(const std::vector<array>& inputs, array& output) = 0;
135
 // Adapters from the multi-output interface: only outputs[0] is used, and it
 // is indexed without checking that `outputs` is non-empty.
136 inline void eval_cpu(
137 const std::vector<array>& inputs,
138 std::vector<array>& outputs) override {
139 eval_cpu(inputs, outputs[0]);
140 }
141 inline void eval_gpu(
142 const std::vector<array>& inputs,
143 std::vector<array>& outputs) override {
144 eval_gpu(inputs, outputs[0]);
145 }
146
 // Copy construction and copy assignment are deleted.
147 virtual ~UnaryPrimitive() = default;
148 UnaryPrimitive(const UnaryPrimitive& other) = delete;
150 UnaryPrimitive& operator=(const UnaryPrimitive& other) = delete;
152};
153
// Abs: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably element-wise absolute value -- confirm in the
// implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members (e.g. the constructor) are not shown.
154class Abs : public UnaryPrimitive {
155 public:
157
158 void eval_cpu(const std::vector<array>& inputs, array& out) override;
159 void eval_gpu(const std::vector<array>& inputs, array& out) override;
160
166
167 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
168 void eval(const std::vector<array>& inputs, array& out);
169};
170
// Add: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably element-wise addition of its inputs -- confirm in
// the implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members (e.g. the constructor) are not shown.
171class Add : public UnaryPrimitive {
172 public:
174
175 void eval_cpu(const std::vector<array>& inputs, array& out) override;
176 void eval_gpu(const std::vector<array>& inputs, array& out) override;
177
183
184 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
185 void eval(const std::vector<array>& inputs, array& out);
186};
187
// AddMM: single-output primitive parameterized by two float coefficients,
// alpha and beta, fixed at construction (both members are const).
// NOTE(review): presumably computes alpha * (a @ b) + beta * c -- confirm the
// exact formula and input order in the implementation.
188class AddMM : public UnaryPrimitive {
189 public:
 // Stores `alpha` and `beta` and forwards `stream` to the base class.
190 explicit AddMM(Stream stream, float alpha, float beta)
191 : UnaryPrimitive(stream), alpha_(alpha), beta_(beta) {}
192
193 void eval_cpu(const std::vector<array>& inputs, array& out) override;
194 void eval_gpu(const std::vector<array>& inputs, array& out) override;
195
 // Only vjp is declared explicitly in the visible lines; listing lines
 // 202-203 are absent.
196 std::vector<array> vjp(
197 const std::vector<array>& primals,
198 const std::vector<array>& cotangents,
199 const std::vector<int>& argnums,
200 const std::vector<array>& outputs) override;
201
204
205 bool is_equivalent(const Primitive& other) const override;
206
207 private:
208 const float alpha_;
209 const float beta_;
210};
211
// Arange: single-output primitive that stores a (start, stop, step) triple of
// doubles. Presumably generates an evenly spaced range -- confirm whether
// `stop` is exclusive in the implementation.
212class Arange : public UnaryPrimitive {
213 public:
 // Stores the three range parameters as given; no validation is done here.
214 explicit Arange(Stream stream, double start, double stop, double step)
215 : UnaryPrimitive(stream), start_(start), stop_(stop), step_(step) {}
216
217 void eval_cpu(const std::vector<array>& inputs, array& out) override;
218 void eval_gpu(const std::vector<array>& inputs, array& out) override;
219
 // Custom equivalence and shape inference, defined out of line.
221 bool is_equivalent(const Primitive& other) const override;
222 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
223
224 private:
225 double start_;
226 double stop_;
227 double step_;
228
 // Private helper with the same parameters as eval_cpu/eval_gpu.
229 void eval(const std::vector<array>& inputs, array& out);
230};
231
// ArcCos: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably element-wise inverse cosine -- confirm in the
// implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members (e.g. the constructor) are not shown.
232class ArcCos : public UnaryPrimitive {
233 public:
235
236 void eval_cpu(const std::vector<array>& inputs, array& out) override;
237 void eval_gpu(const std::vector<array>& inputs, array& out) override;
238
244
245 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
246 void eval(const std::vector<array>& inputs, array& out);
247};
248
// ArcCosh: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably element-wise inverse hyperbolic cosine -- confirm
// in the implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members (e.g. the constructor) are not shown.
249class ArcCosh : public UnaryPrimitive {
250 public:
252
253 void eval_cpu(const std::vector<array>& inputs, array& out) override;
254 void eval_gpu(const std::vector<array>& inputs, array& out) override;
255
261
262 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
263 void eval(const std::vector<array>& inputs, array& out);
264};
265
// ArcSin: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably element-wise inverse sine -- confirm in the
// implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members (e.g. the constructor) are not shown.
266class ArcSin : public UnaryPrimitive {
267 public:
269
270 void eval_cpu(const std::vector<array>& inputs, array& out) override;
271 void eval_gpu(const std::vector<array>& inputs, array& out) override;
272
278
279 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
280 void eval(const std::vector<array>& inputs, array& out);
281};
282
// ArcSinh: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably element-wise inverse hyperbolic sine -- confirm in
// the implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members (e.g. the constructor) are not shown.
283class ArcSinh : public UnaryPrimitive {
284 public:
286
287 void eval_cpu(const std::vector<array>& inputs, array& out) override;
288 void eval_gpu(const std::vector<array>& inputs, array& out) override;
289
295
296 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
297 void eval(const std::vector<array>& inputs, array& out);
298};
299
// ArcTan: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably element-wise inverse tangent -- confirm in the
// implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members (e.g. the constructor) are not shown.
300class ArcTan : public UnaryPrimitive {
301 public:
303
304 void eval_cpu(const std::vector<array>& inputs, array& out) override;
305 void eval_gpu(const std::vector<array>& inputs, array& out) override;
306
312
313 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
314 void eval(const std::vector<array>& inputs, array& out);
315};
316
// ArcTan2: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably the two-argument inverse tangent -- confirm the
// argument order in the implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members (e.g. the constructor) are not shown.
317class ArcTan2 : public UnaryPrimitive {
318 public:
320
321 void eval_cpu(const std::vector<array>& inputs, array& out) override;
322 void eval_gpu(const std::vector<array>& inputs, array& out) override;
323
329
330 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
331 void eval(const std::vector<array>& inputs, array& out);
332};
333
// ArcTanh: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably element-wise inverse hyperbolic tangent -- confirm
// in the implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members (e.g. the constructor) are not shown.
334class ArcTanh : public UnaryPrimitive {
335 public:
337
338 void eval_cpu(const std::vector<array>& inputs, array& out) override;
339 void eval_gpu(const std::vector<array>& inputs, array& out) override;
340
346
347 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
348 void eval(const std::vector<array>& inputs, array& out);
349};
350
352 public:
353 explicit ArgPartition(Stream stream, int kth, int axis)
354 : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
355
356 void eval_cpu(const std::vector<array>& inputs, array& out) override;
357 void eval_gpu(const std::vector<array>& inputs, array& out) override;
358
363 bool is_equivalent(const Primitive& other) const override;
364
365 private:
366 int kth_;
367 int axis_;
368
369 void eval(const std::vector<array>& inputs, array& out);
370};
371
// ArgReduce: single-output primitive configured with a reduction kind and an
// axis. Presumably returns indices (e.g. argmin/argmax) along `axis` --
// confirm in the implementation.
// NOTE(review): ReduceType is used below but its declaration is not visible;
// listing lines 374-377 are absent and presumably contain it.
372class ArgReduce : public UnaryPrimitive {
373 public:
378
 // Stores the reduction kind and axis as given.
379 explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)
380 : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {}
381
382 void eval_cpu(const std::vector<array>& inputs, array& out) override;
383 void eval_gpu(const std::vector<array>& inputs, array& out) override;
384
 // Custom equivalence and shape inference, defined out of line.
388 bool is_equivalent(const Primitive& other) const override;
389 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
390
391 private:
392 ReduceType reduce_type_;
393 int axis_;
394
 // Private helper with the same parameters as eval_cpu/eval_gpu.
395 void eval(const std::vector<array>& inputs, array& out);
396};
397
// ArgSort: single-output primitive configured with one axis. Presumably
// returns the indices that would sort the input along `axis` -- confirm in
// the implementation.
398class ArgSort : public UnaryPrimitive {
399 public:
 // Stores `axis` as given; no range check is done here.
400 explicit ArgSort(Stream stream, int axis)
401 : UnaryPrimitive(stream), axis_(axis) {}
402
403 void eval_cpu(const std::vector<array>& inputs, array& out) override;
404 void eval_gpu(const std::vector<array>& inputs, array& out) override;
405
409 bool is_equivalent(const Primitive& other) const override;
410
411 private:
412 int axis_;
413
 // Private helper with the same parameters as eval_cpu/eval_gpu.
414 void eval(const std::vector<array>& inputs, array& out);
415};
416
// AsType: single-output primitive that stores a target Dtype. Presumably
// converts its input to that dtype -- confirm in the implementation.
417class AsType : public UnaryPrimitive {
418 public:
 // Stores the dtype supplied by the caller.
419 explicit AsType(Stream stream, Dtype dtype)
420 : UnaryPrimitive(stream), dtype_(dtype) {}
421
422 void eval_cpu(const std::vector<array>& inputs, array& out) override;
423 void eval_gpu(const std::vector<array>& inputs, array& out) override;
424
429 bool is_equivalent(const Primitive& other) const override;
430
431 private:
432 Dtype dtype_;
433
 // Private helper with the same parameters as eval_cpu/eval_gpu.
434 void eval(const std::vector<array>& inputs, array& out);
435};
436
// AsStrided: single-output primitive that stores a shape, a set of strides
// and an element offset. Presumably produces a strided view of its input --
// confirm in the implementation.
// NOTE(review): listing line 440 is absent; it presumably holds the start of
// the member-initializer list (": UnaryPrimitive(stream),").
437class AsStrided : public UnaryPrimitive {
438 public:
 // `shape` and `strides` are taken by value and moved into the members.
439 explicit AsStrided(Stream stream, Shape shape, Strides strides, size_t offset)
441 shape_(std::move(shape)),
442 strides_(std::move(strides)),
443 offset_(offset) {}
444
445 void eval_cpu(const std::vector<array>& inputs, array& out) override;
446 void eval_gpu(const std::vector<array>& inputs, array& out) override;
447
450 bool is_equivalent(const Primitive& other) const override;
451
452 private:
453 Shape shape_;
454 Strides strides_;
455 size_t offset_;
456
 // Private helper with the same parameters as eval_cpu/eval_gpu.
457 void eval(const std::vector<array>& inputs, array& out);
458};
459
461 public:
462 enum Op { And, Or, Xor, LeftShift, RightShift };
463
466
467 void eval_cpu(const std::vector<array>& inputs, array& out) override;
468 void eval_gpu(const std::vector<array>& inputs, array& out) override;
469
472 bool is_equivalent(const Primitive& other) const override;
473 void print(std::ostream& os) override;
475
476 private:
477 Op op_;
478};
479
481 public:
482 explicit BlockMaskedMM(Stream stream, int block_size)
483 : UnaryPrimitive(stream), block_size_(block_size) {}
484
485 void eval_cpu(const std::vector<array>& inputs, array& out) override;
486 void eval_gpu(const std::vector<array>& inputs, array& out) override;
487
488 std::vector<array> vjp(
489 const std::vector<array>& primals,
490 const std::vector<array>& cotangents,
491 const std::vector<int>& argnums,
492 const std::vector<array>& outputs) override;
493
495 bool is_equivalent(const Primitive& other) const override;
496
497 private:
498 int block_size_;
499
500 void eval(const std::vector<array>& inputs, array& out);
501};
502
// GatherMM: primitive deriving from UnaryPrimitive; it evaluates into a
// single output array. Presumably a matrix multiply over gathered operands --
// confirm in the implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members (e.g. the constructor) are not shown.
503class GatherMM : public UnaryPrimitive {
504 public:
506
507 void eval_cpu(const std::vector<array>& inputs, array& out) override;
508 void eval_gpu(const std::vector<array>& inputs, array& out) override;
509
 // Only vjp is declared explicitly in the visible lines.
510 std::vector<array> vjp(
511 const std::vector<array>& primals,
512 const std::vector<array>& cotangents,
513 const std::vector<int>& argnums,
514 const std::vector<array>& outputs) override;
515
518
519 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
520 void eval(const std::vector<array>& inputs, array& out);
521};
522
// Broadcast: single-output primitive that stores a target shape. Presumably
// broadcasts its input to that shape -- confirm in the implementation.
523class Broadcast : public UnaryPrimitive {
524 public:
 // Copies `shape` into the member.
525 explicit Broadcast(Stream stream, const Shape& shape)
526 : UnaryPrimitive(stream), shape_(shape) {}
527
528 void eval_cpu(const std::vector<array>& inputs, array& out) override;
529 void eval_gpu(const std::vector<array>& inputs, array& out) override;
530
534 bool is_equivalent(const Primitive& other) const override;
535
536 private:
537 Shape shape_;
538
 // Private helper with the same parameters as eval_cpu/eval_gpu.
539 void eval(const std::vector<array>& inputs, array& out);
540};
541
// Ceil: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably element-wise ceiling -- confirm in the
// implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members (e.g. the constructor) are not shown.
542class Ceil : public UnaryPrimitive {
543 public:
545
546 void eval_cpu(const std::vector<array>& inputs, array& out) override;
547 void eval_gpu(const std::vector<array>& inputs, array& out) override;
548
554
555 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
556 void eval(const std::vector<array>& inputs, array& out);
557};
558
// Compiled: multi-output primitive (derives directly from Primitive) that
// holds a captured graph: its inputs, outputs, a tape of arrays and the ids
// of inputs that may be treated as scalar constants.
// NOTE(review): listing line 571 is absent from the constructor's parameter
// list (presumably the Stream parameter), as are lines 582-583.
559class Compiled : public Primitive {
560 public:
561 /*
562 * The inputs, outputs and tape are either tracers or constants.
563 * - The tape should not contain the inputs, but it should contain the
564 * outputs.
565 * - The tape should also have only one array per primitive for multi-output
566 * primitives.
567 * - The constant_ids contains ids of arrays in the input list that are safe
568 * to treat as scalar constants.
569 */
570 explicit Compiled(
572 std::vector<array> inputs,
573 std::vector<array> outputs,
574 std::vector<array> tape,
575 std::unordered_set<uintptr_t> constant_ids);
576
577 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
578 override;
579 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
580 override;
581
584 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
585 void print(std::ostream& os) override;
586 bool is_equivalent(const Primitive& other) const override;
587
 // Returns a copy of the stored kernel library name.
588 std::string lib_name() const {
589 return kernel_lib_;
590 }
591
592 private:
 // Captured graph description; all four members are const.
593 const std::vector<array> inputs_;
594 const std::vector<array> outputs_;
595 const std::vector<array> tape_;
596 const std::unordered_set<uintptr_t> constant_ids_;
597
 // Unlike the members above this one is not const, so it can be assigned
 // after construction.
598 std::string kernel_lib_;
599};
600
602 public:
603 explicit Concatenate(Stream stream, int axis)
604 : UnaryPrimitive(stream), axis_(axis) {}
605
606 void eval_cpu(const std::vector<array>& inputs, array& out) override;
607 void eval_gpu(const std::vector<array>& inputs, array& out) override;
608
612 bool is_equivalent(const Primitive& other) const override;
613 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
614
615 private:
616 int axis_;
617
618 void eval(const std::vector<array>& inputs, array& out);
619};
620
// Conjugate: primitive deriving from UnaryPrimitive; it evaluates into a
// single output array. Presumably the element-wise complex conjugate --
// confirm in the implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members are not shown.
621class Conjugate : public UnaryPrimitive {
622 public:
 // No state beyond the stream forwarded to the base class.
623 explicit Conjugate(Stream stream) : UnaryPrimitive(stream) {}
624
625 void eval_cpu(const std::vector<array>& inputs, array& out) override;
626 void eval_gpu(const std::vector<array>& inputs, array& out) override;
627
632
633 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
634 void eval(const std::vector<array>& inputs, array& out);
635};
636
638 public:
639 explicit Contiguous(Stream stream, bool allow_col_major)
640 : UnaryPrimitive(stream), allow_col_major_(allow_col_major) {}
641
642 void eval_cpu(const std::vector<array>& inputs, array& out) override;
643 void eval_gpu(const std::vector<array>& inputs, array& out) override;
644
649
650 bool is_equivalent(const Primitive& other) const override;
651
652 private:
653 bool allow_col_major_;
654};
655
657 public:
658 explicit Convolution(
659 Stream stream,
660 const std::vector<int>& kernel_strides,
661 const std::vector<int>& padding,
662 const std::vector<int>& kernel_dilation,
663 const std::vector<int>& input_dilation,
664 const int groups = 1,
665 const bool flip = false)
666 : UnaryPrimitive(stream),
667 padding_(padding),
668 kernel_strides_(kernel_strides),
669 kernel_dilation_(kernel_dilation),
670 input_dilation_(input_dilation),
671 groups_(groups),
672 flip_(flip) {}
673
674 void eval_cpu(const std::vector<array>& inputs, array& out) override;
675 void eval_gpu(const std::vector<array>& inputs, array& out) override;
676
677 std::vector<array> vjp(
678 const std::vector<array>& primals,
679 const std::vector<array>& cotangents,
680 const std::vector<int>& argnums,
681 const std::vector<array>& outputs) override;
682
684 bool is_equivalent(const Primitive& other) const override;
685
686 private:
687 std::vector<int> padding_;
688 std::vector<int> kernel_strides_;
689 std::vector<int> kernel_dilation_;
690 std::vector<int> input_dilation_;
691 int groups_;
692 bool flip_;
693
694 void eval(const std::vector<array>& inputs, array& out);
695};
696
// Copy: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably copies its input -- confirm in the implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members are not shown.
697class Copy : public UnaryPrimitive {
698 public:
 // No state beyond the stream forwarded to the base class.
699 explicit Copy(Stream stream) : UnaryPrimitive(stream) {}
700
701 void eval_cpu(const std::vector<array>& inputs, array& out) override;
702 void eval_gpu(const std::vector<array>& inputs, array& out) override;
703
709
710 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
711 void eval(const std::vector<array>& inputs, array& out);
712};
713
// Cos: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably element-wise cosine -- confirm in the
// implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members are not shown.
714class Cos : public UnaryPrimitive {
715 public:
 // No state beyond the stream forwarded to the base class.
716 explicit Cos(Stream stream) : UnaryPrimitive(stream) {}
717
718 void eval_cpu(const std::vector<array>& inputs, array& out) override;
719 void eval_gpu(const std::vector<array>& inputs, array& out) override;
720
726
727 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
728 void eval(const std::vector<array>& inputs, array& out);
729};
730
// Cosh: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably element-wise hyperbolic cosine -- confirm in the
// implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members are not shown.
731class Cosh : public UnaryPrimitive {
732 public:
 // No state beyond the stream forwarded to the base class.
733 explicit Cosh(Stream stream) : UnaryPrimitive(stream) {}
734
735 void eval_cpu(const std::vector<array>& inputs, array& out) override;
736 void eval_gpu(const std::vector<array>& inputs, array& out) override;
737
743
744 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
745 void eval(const std::vector<array>& inputs, array& out);
746};
747
749 public:
751 Stream stream,
752 int num_outputs,
753 std::function<std::vector<array>(
754 const std::vector<array>&,
755 const std::vector<array>&,
756 const std::vector<array>&)> vjp,
757 std::function<std::vector<array>(
758 const std::vector<array>&,
759 const std::vector<array>&,
760 const std::vector<int>&)> jvp,
761 std::function<std::pair<std::vector<array>, std::vector<int>>(
762 const std::vector<array>&,
763 const std::vector<int>&)> vmap)
764 : Primitive(stream),
765 num_outputs_(num_outputs),
766 vjp_fun_(std::move(vjp)),
767 jvp_fun_(std::move(jvp)),
768 vmap_fun_(std::move(vmap)) {}
769
770 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
771 override;
772 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
773 override;
774
778
779 private:
780 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
781
782 int num_outputs_;
783
784 std::function<std::vector<array>(
785 const std::vector<array>&,
786 const std::vector<array>&,
787 const std::vector<array>&)>
788 vjp_fun_;
789 std::function<std::vector<array>(
790 const std::vector<array>&,
791 const std::vector<array>&,
792 const std::vector<int>&)>
793 jvp_fun_;
794 std::function<std::pair<std::vector<array>, std::vector<int>>(
795 const std::vector<array>&,
796 const std::vector<int>&)>
797 vmap_fun_;
798};
799
// Depends: multi-output primitive (derives directly from Primitive) with no
// state of its own. NOTE(review): presumably used to express an evaluation
// dependency between arrays -- confirm in the implementation.
800class Depends : public Primitive {
801 public:
 // No state beyond the stream forwarded to the base class.
802 explicit Depends(Stream stream) : Primitive(stream) {}
803
804 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
805 override;
806 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
807 override;
808
 // Only vjp is declared explicitly in the visible lines. The second parameter
 // is named `cotan` here rather than `cotangents` as in Primitive::vjp.
809 std::vector<array> vjp(
810 const std::vector<array>& primals,
811 const std::vector<array>& cotan,
812 const std::vector<int>& argnums,
813 const std::vector<array>& outputs) override;
814
816
817 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
818 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
819};
820
// Divide: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably element-wise division -- confirm in the
// implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members are not shown.
821class Divide : public UnaryPrimitive {
822 public:
 // No state beyond the stream forwarded to the base class.
823 explicit Divide(Stream stream) : UnaryPrimitive(stream) {}
824
825 void eval_cpu(const std::vector<array>& inputs, array& out) override;
826 void eval_gpu(const std::vector<array>& inputs, array& out) override;
827
833
834 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
835 void eval(const std::vector<array>& inputs, array& out);
836};
837
// DivMod: multi-output primitive (derives directly from Primitive). Its
// output_shapes() reports two outputs. Presumably quotient and remainder --
// confirm the output order in the implementation.
838class DivMod : public Primitive {
839 public:
 // No state beyond the stream forwarded to the base class.
840 explicit DivMod(Stream stream) : Primitive(stream) {}
841
842 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
843 override;
844 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
845 override;
846
 // Both outputs take the shape of the first input. inputs[0] is indexed
 // without checking that `inputs` is non-empty.
851 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
852 return std::vector{inputs[0].shape(), inputs[0].shape()};
853 }
854
855 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
856 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
857};
858
// Select: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably chooses between inputs based on a condition input
// -- confirm the input order in the implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members are not shown.
859class Select : public UnaryPrimitive {
860 public:
 // No state beyond the stream forwarded to the base class.
861 explicit Select(Stream stream) : UnaryPrimitive(stream) {}
862
863 void eval_cpu(const std::vector<array>& inputs, array& out) override;
864 void eval_gpu(const std::vector<array>& inputs, array& out) override;
865
871
872 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
873 void eval(const std::vector<array>& inputs, array& out);
874};
875
// Remainder: primitive deriving from UnaryPrimitive; it evaluates into a
// single output array. Presumably element-wise remainder -- confirm the sign
// convention in the implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members are not shown.
876class Remainder : public UnaryPrimitive {
877 public:
 // No state beyond the stream forwarded to the base class.
878 explicit Remainder(Stream stream) : UnaryPrimitive(stream) {}
879
880 void eval_cpu(const std::vector<array>& inputs, array& out) override;
881 void eval_gpu(const std::vector<array>& inputs, array& out) override;
882
888
889 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
890 void eval(const std::vector<array>& inputs, array& out);
891};
892
// Equal: single-output primitive with one boolean option, equal_nan
// (default false). Presumably an element-wise equality comparison where
// equal_nan makes NaNs compare equal -- confirm in the implementation.
893class Equal : public UnaryPrimitive {
894 public:
 // Stores the equal_nan flag.
895 explicit Equal(Stream stream, bool equal_nan = false)
896 : UnaryPrimitive(stream), equal_nan_(equal_nan) {}
897
898 void eval_cpu(const std::vector<array>& inputs, array& out) override;
899 void eval_gpu(const std::vector<array>& inputs, array& out) override;
900
905
 // Prints "NaNEqual" when the flag is set and "Equal" otherwise.
906 void print(std::ostream& os) override {
907 if (equal_nan_) {
908 os << "NaNEqual";
909 } else {
910 os << "Equal";
911 }
912 }
913
914 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
915 void eval(const std::vector<array>& inputs, array& out);
916 bool equal_nan_;
917};
918
// Erf: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably the element-wise error function -- confirm in the
// implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members are not shown.
919class Erf : public UnaryPrimitive {
920 public:
 // No state beyond the stream forwarded to the base class.
921 explicit Erf(Stream stream) : UnaryPrimitive(stream) {}
922
923 void eval_cpu(const std::vector<array>& inputs, array& out) override;
924 void eval_gpu(const std::vector<array>& inputs, array& out) override;
925
931
932 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
933 void eval(const std::vector<array>& inputs, array& out);
934};
935
// ErfInv: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably the element-wise inverse error function -- confirm
// in the implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members are not shown.
936class ErfInv : public UnaryPrimitive {
937 public:
 // No state beyond the stream forwarded to the base class.
938 explicit ErfInv(Stream stream) : UnaryPrimitive(stream) {}
939
940 void eval_cpu(const std::vector<array>& inputs, array& out) override;
941 void eval_gpu(const std::vector<array>& inputs, array& out) override;
942
948
949 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
950 void eval(const std::vector<array>& inputs, array& out);
951};
952
// Exp: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably the element-wise exponential -- confirm in the
// implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members are not shown.
953class Exp : public UnaryPrimitive {
954 public:
 // No state beyond the stream forwarded to the base class.
955 explicit Exp(Stream stream) : UnaryPrimitive(stream) {}
956
957 void eval_cpu(const std::vector<array>& inputs, array& out) override;
958 void eval_gpu(const std::vector<array>& inputs, array& out) override;
959
965
966 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
967 void eval(const std::vector<array>& inputs, array& out);
968};
969
// Expm1: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably element-wise exp(x) - 1 -- confirm in the
// implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members are not shown.
970class Expm1 : public UnaryPrimitive {
971 public:
 // No state beyond the stream forwarded to the base class.
972 explicit Expm1(Stream stream) : UnaryPrimitive(stream) {}
973
974 void eval_cpu(const std::vector<array>& inputs, array& out) override;
975 void eval_gpu(const std::vector<array>& inputs, array& out) override;
976
981
982 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
983 void eval(const std::vector<array>& inputs, array& out);
984};
985
// FFT: single-output primitive configured with a list of axes and two flags,
// `inverse` and `real`. Presumably a (possibly inverse, possibly real-valued)
// Fourier transform over the given axes -- confirm in the implementation.
986class FFT : public UnaryPrimitive {
987 public:
 // Copies `axes` and stores both flags as given.
988 explicit FFT(
989 Stream stream,
990 const std::vector<size_t>& axes,
991 bool inverse,
992 bool real)
993 : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {}
994
995 void eval_cpu(const std::vector<array>& inputs, array& out) override;
996 void eval_gpu(const std::vector<array>& inputs, array& out) override;
997
1001
1002 bool is_equivalent(const Primitive& other) const override;
1003
1004 private:
1005 std::vector<size_t> axes_;
1006 bool inverse_;
1007 bool real_;
1008
 // Private helper with the same parameters as eval_cpu/eval_gpu.
1009 void eval(const std::vector<array>& inputs, array& out);
1010};
1011
// Floor: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably element-wise floor -- confirm in the
// implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members are not shown.
1012class Floor : public UnaryPrimitive {
1013 public:
 // No state beyond the stream forwarded to the base class.
1014 explicit Floor(Stream stream) : UnaryPrimitive(stream) {}
1015
1016 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1017 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1018
1024
1025 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
1026 void eval(const std::vector<array>& inputs, array& out);
1027};
1028
// Full: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably fills the output with a value taken from its
// input -- confirm in the implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members are not shown.
1029class Full : public UnaryPrimitive {
1030 public:
 // No state beyond the stream forwarded to the base class.
1031 explicit Full(Stream stream) : UnaryPrimitive(stream) {}
1032
1033 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1034 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1035
1040
1041 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
1042 void eval(const std::vector<array>& inputs, array& out);
1043};
1044
// Gather: single-output primitive configured with a list of axes and a list
// of slice sizes. Presumably gathers slices of the source along those axes --
// confirm how axes and slice_sizes relate in the implementation.
1045class Gather : public UnaryPrimitive {
1046 public:
 // Copies both vectors into the members.
1047 explicit Gather(
1048 Stream stream,
1049 const std::vector<int>& axes,
1050 const std::vector<int>& slice_sizes)
1051 : UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes) {}
1052
1053 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1054 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1055
 // Custom equivalence and shape inference, defined out of line.
1059 bool is_equivalent(const Primitive& other) const override;
1060 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1061
1062 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
1063 void eval(const std::vector<array>& inputs, array& out);
1064 std::vector<int> axes_;
1065 std::vector<int> slice_sizes_;
1066};
1067
// Greater: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably an element-wise "greater than" comparison --
// confirm in the implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members are not shown.
1068class Greater : public UnaryPrimitive {
1069 public:
 // No state beyond the stream forwarded to the base class.
1070 explicit Greater(Stream stream) : UnaryPrimitive(stream) {}
1071
1072 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1073 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1074
1080
1081 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
1082 void eval(const std::vector<array>& inputs, array& out);
1083};
1084
1086 public:
1087 explicit GreaterEqual(Stream stream) : UnaryPrimitive(stream) {}
1088
1089 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1090 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1091
1097
1098 private:
1099 void eval(const std::vector<array>& inputs, array& out);
1100};
1101
// Hadamard: single-output primitive that stores a float scale factor.
// Presumably a Hadamard transform whose result is multiplied by `scale` --
// confirm in the implementation.
1102class Hadamard : public UnaryPrimitive {
1103 public:
 // Stores `scale` as given.
1104 explicit Hadamard(Stream stream, float scale)
1105 : UnaryPrimitive(stream), scale_(scale) {}
1106
1107 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1108 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1109
1114
1115 bool is_equivalent(const Primitive& other) const override;
1116
1117 private:
1118 float scale_;
1119
 // Private helper with the same parameters as eval_cpu/eval_gpu.
1120 void eval(const std::vector<array>& inputs, array& out);
1121};
1122
// Imag: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably extracts the imaginary part of a complex input --
// confirm in the implementation. Unlike most sibling classes, no private
// eval() helper is declared in the visible lines.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members are not shown.
1123class Imag : public UnaryPrimitive {
1124 public:
 // No state beyond the stream forwarded to the base class.
1125 explicit Imag(Stream stream) : UnaryPrimitive(stream) {}
1126
1127 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1128 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1129
1135};
1136
// Less: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably an element-wise "less than" comparison -- confirm
// in the implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members are not shown.
1137class Less : public UnaryPrimitive {
1138 public:
 // No state beyond the stream forwarded to the base class.
1139 explicit Less(Stream stream) : UnaryPrimitive(stream) {}
1140
1141 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1142 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1143
1149
1150 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
1151 void eval(const std::vector<array>& inputs, array& out);
1152};
1153
1155 public:
1156 explicit LessEqual(Stream stream) : UnaryPrimitive(stream) {}
1157
1158 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1159 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1160
1166
1167 private:
1168 void eval(const std::vector<array>& inputs, array& out);
1169};
1170
// Load: single-output primitive that holds a shared io::Reader, a byte/element
// offset and an endianness-swap flag. Presumably reads array data from the
// reader starting at `offset` -- confirm the offset unit in the
// implementation.
1171class Load : public UnaryPrimitive {
1172 public:
 // Takes shared ownership of `reader` (moved into the member). When the
 // stream's device is Device::gpu the constructor calls io_stream(), which
 // forces initialization of that function's static stream.
1173 explicit Load(
1174 Stream stream,
1175 std::shared_ptr<io::Reader> reader,
1176 size_t offset,
1177 bool swap_endianness = false)
1178 : UnaryPrimitive(stream),
1179 reader_(std::move(reader)),
1180 offset_(offset),
1181 swap_endianness_(swap_endianness) {
1182 if (stream.device == Device::gpu) {
1183 io_stream();
1184 }
1185 }
1186
1187 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1188 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1189
1191
1192 private:
 // Returns a function-local static stream created once via
 // new_stream(Device::cpu); every Load instance therefore sees the same one.
 // NOTE(review): the ';' after the closing brace below is redundant.
1193 Stream& io_stream() {
1194 static Stream io_stream = new_stream(Device::cpu);
1195 return io_stream;
1196 };
 // Private helper with the same parameters as eval_cpu/eval_gpu.
1197 void eval(const std::vector<array>& inputs, array& out);
1198 std::shared_ptr<io::Reader> reader_;
1199 size_t offset_;
1200 bool swap_endianness_;
1201};
1202
// Log: single-output primitive parameterized by a logarithm base chosen from
// the nested Base enum. Presumably the element-wise logarithm in that base --
// confirm in the implementation.
1203class Log : public UnaryPrimitive {
1204 public:
 // Available bases: two, ten and e.
1205 enum Base { two, ten, e };
1206
 // Stores the selected base.
1207 explicit Log(Stream stream, Base base)
1208 : UnaryPrimitive(stream), base_(base) {}
1209
1210 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1211 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1212
1217
 // Prints "Log" for base e, "Log2" for base two and "Log10" for base ten.
 // There is no default case, so any other value prints nothing.
1218 void print(std::ostream& os) override {
1219 switch (base_) {
1220 case e:
1221 os << "Log";
1222 break;
1223 case two:
1224 os << "Log2";
1225 break;
1226 case ten:
1227 os << "Log10";
1228 break;
1229 }
1230 }
1231
1232 private:
1233 Base base_;
 // Private helper with the same parameters as eval_cpu/eval_gpu.
1234 void eval(const std::vector<array>& inputs, array& out);
1235};
1236
// Log1p: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably element-wise log(1 + x) -- confirm in the
// implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members are not shown.
1237class Log1p : public UnaryPrimitive {
1238 public:
 // No state beyond the stream forwarded to the base class.
1239 explicit Log1p(Stream stream) : UnaryPrimitive(stream) {}
1240
1241 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1242 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1243
1248
1249 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
1250 void eval(const std::vector<array>& inputs, array& out);
1251};
1252
1254 public:
1255 explicit LogicalNot(Stream stream) : UnaryPrimitive(stream) {}
1256
1257 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1258 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1259
1265
1266 private:
1267 void eval(const std::vector<array>& inputs, array& out);
1268};
1269
1271 public:
1272 explicit LogicalAnd(Stream stream) : UnaryPrimitive(stream) {}
1273
1274 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1275 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1276
1282
1283 private:
1284 void eval(const std::vector<array>& inputs, array& out);
1285};
1286
1288 public:
1289 explicit LogicalOr(Stream stream) : UnaryPrimitive(stream) {}
1290
1291 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1292 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1293
1299
1300 private:
1301 void eval(const std::vector<array>& inputs, array& out);
1302};
1303
1305 public:
1306 explicit LogAddExp(Stream stream) : UnaryPrimitive(stream) {}
1307
1308 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1309 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1310
1316
1317 private:
1318 void eval(const std::vector<array>& inputs, array& out);
1319};
1320
// Matmul: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably matrix multiplication of its inputs -- confirm in
// the implementation. No private eval() helper is declared in the visible
// lines.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members are not shown.
1321class Matmul : public UnaryPrimitive {
1322 public:
 // No state beyond the stream forwarded to the base class.
1323 explicit Matmul(Stream stream) : UnaryPrimitive(stream) {}
1324
1325 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1326 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1327
 // Only vjp is declared explicitly in the visible lines.
1328 std::vector<array> vjp(
1329 const std::vector<array>& primals,
1330 const std::vector<array>& cotangents,
1331 const std::vector<int>& argnums,
1332 const std::vector<array>& outputs) override;
1333
 // Custom shape inference, defined out of line.
1337 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1338};
1339
// Maximum: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably the element-wise maximum of its inputs -- confirm
// in the implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members are not shown.
1340class Maximum : public UnaryPrimitive {
1341 public:
 // No state beyond the stream forwarded to the base class.
1342 explicit Maximum(Stream stream) : UnaryPrimitive(stream) {}
1343
1344 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1345 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1346
1352
1353 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
1354 void eval(const std::vector<array>& inputs, array& out);
1355};
1356
// Minimum: primitive deriving from UnaryPrimitive; it evaluates into a single
// output array. Presumably the element-wise minimum of its inputs -- confirm
// in the implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members are not shown.
1357class Minimum : public UnaryPrimitive {
1358 public:
 // No state beyond the stream forwarded to the base class.
1359 explicit Minimum(Stream stream) : UnaryPrimitive(stream) {}
1360
1361 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1362 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1363
1369
1370 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
1371 void eval(const std::vector<array>& inputs, array& out);
1372};
1373
// Multiply: primitive deriving from UnaryPrimitive; it evaluates into a
// single output array. Presumably element-wise multiplication -- confirm in
// the implementation.
// NOTE(review): the embedded listing numbers are not contiguous here, so some
// members are not shown.
1374class Multiply : public UnaryPrimitive {
1375 public:
 // No state beyond the stream forwarded to the base class.
1376 explicit Multiply(Stream stream) : UnaryPrimitive(stream) {}
1377
1378 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1379 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1380
1386
1387 private:
 // Private helper with the same parameters as eval_cpu/eval_gpu.
1388 void eval(const std::vector<array>& inputs, array& out);
1389};
1390
// Negative primitive (a UnaryPrimitive subclass): stream-only constructor
// plus CPU/GPU evaluation overrides.
1391class Negative : public UnaryPrimitive {
1392 public:
// Forwards the stream to the UnaryPrimitive constructor.
1393 explicit Negative(Stream stream) : UnaryPrimitive(stream) {}
1394
// Backend evaluation overrides: take the input arrays and one output array.
1395 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1396 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1397
// NOTE(review): listing lines 1398-1402 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT /
// ...); confirm against the upstream header.
1403
1404 private:
// Private evaluation helper; its definition is not visible in this header.
1405 void eval(const std::vector<array>& inputs, array& out);
1406};
1407
// NotEqual primitive (a UnaryPrimitive subclass): stream-only constructor
// plus CPU/GPU evaluation overrides.
1408class NotEqual : public UnaryPrimitive {
1409 public:
// Forwards the stream to the UnaryPrimitive constructor.
1410 explicit NotEqual(Stream stream) : UnaryPrimitive(stream) {}
1411
// Backend evaluation overrides: take the input arrays and one output array.
1412 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1413 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1414
// NOTE(review): listing lines 1415-1419 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT /
// ...); confirm against the upstream header.
1420
1421 private:
// Private evaluation helper; its definition is not visible in this header.
1422 void eval(const std::vector<array>& inputs, array& out);
1423};
1424
// Body of a primitive whose `class` header (listing line 1425) and
// constructor-name line (listing line 1427) are both missing from this
// extract.
// NOTE(review): presumably this is NumberOfElements deriving from
// UnaryPrimitive (the initializer list delegates to UnaryPrimitive(stream));
// the name is not visible here -- confirm against the upstream header.
1426 public:
// Constructor parameter list (the line naming the constructor is missing).
// `axes` is taken by value and moved into axes_; `inverted` and `dtype` are
// copied into the matching members.
1428 Stream stream,
1429 std::vector<int> axes,
1430 bool inverted,
1431 Dtype dtype)
1432 : UnaryPrimitive(stream),
1433 axes_(std::move(axes)),
1434 inverted_(inverted),
1435 dtype_(dtype) {}
1436
// Backend evaluation overrides: take the input arrays and one output array.
1437 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1438 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1439
// NOTE(review): listing lines 1440-1441 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_PRINT / ...); confirm.
// Equivalence check is defined out of line.
1442 bool is_equivalent(const Primitive& other) const override;
// Always reports exactly one output shape, and that shape is empty (`{}`),
// regardless of the inputs.
1443 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
1444 return {{}};
1445 }
1446
1447 private:
// State captured at construction time.
1448 std::vector<int> axes_;
1449 bool inverted_;
1450 Dtype dtype_;
1451
// Private evaluation helper; its definition is not visible in this header.
1452 void eval(const std::vector<array>& inputs, array& out);
1453};
1454
// Pad primitive (a UnaryPrimitive subclass) parameterized by a list of axes
// and per-axis low/high pad sizes.
1455class Pad : public UnaryPrimitive {
1456 public:
// Copies the three vectors into the corresponding members.
// NOTE(review): presumably the three vectors are parallel (one low/high pad
// entry per listed axis) -- not enforced in this header; confirm at callers.
1457 explicit Pad(
1458 Stream stream,
1459 const std::vector<int>& axes,
1460 const std::vector<int>& low_pad_size,
1461 const std::vector<int>& high_pad_size)
1462 : UnaryPrimitive(stream),
1463 axes_(axes),
1464 low_pad_size_(low_pad_size),
1465 high_pad_size_(high_pad_size) {}
1466
// Backend evaluation overrides: take the input arrays and one output array.
1467 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1468 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1469
// NOTE(review): listing lines 1470-1472 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT);
// confirm against the upstream header.
// Equivalence check is defined out of line.
1473 bool is_equivalent(const Primitive& other) const override;
1474
1475 private:
// Padding configuration captured at construction time.
1476 std::vector<int> axes_;
1477 std::vector<int> low_pad_size_;
1478 std::vector<int> high_pad_size_;
1479
// Private evaluation helper; its definition is not visible in this header.
1480 void eval(const std::vector<array>& inputs, array& out);
1481};
1482
// Body of the Partition primitive declaration. The `class` header line
// (listing line 1483) is missing from this extract.
// NOTE(review): the base is presumably UnaryPrimitive, since the constructor
// delegates to UnaryPrimitive(stream) -- confirm against the upstream header.
1484 public:
// Stores `kth` and `axis` in the matching members.
1485 explicit Partition(Stream stream, int kth, int axis)
1486 : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
1487
// Backend evaluation overrides: take the input arrays and one output array.
1488 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1489 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1490
// NOTE(review): listing lines 1491-1494 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT /
// ...); confirm against the upstream header.
// Equivalence check is defined out of line.
1495 bool is_equivalent(const Primitive& other) const override;
1496
1497 private:
// Parameters captured at construction time.
1498 int kth_;
1499 int axis_;
1500
// Private evaluation helper; its definition is not visible in this header.
1501 void eval(const std::vector<array>& inputs, array& out);
1502};
1503
// Power primitive (a UnaryPrimitive subclass): stream-only constructor plus
// CPU/GPU evaluation overrides.
1504class Power : public UnaryPrimitive {
1505 public:
// Forwards the stream to the UnaryPrimitive constructor.
1506 explicit Power(Stream stream) : UnaryPrimitive(stream) {}
1507
// Backend evaluation overrides: take the input arrays and one output array.
1508 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1509 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1510
// NOTE(review): listing lines 1511-1515 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT /
// ...); confirm against the upstream header.
1516
1517 private:
// Private evaluation helper; its definition is not visible in this header.
1518 void eval(const std::vector<array>& inputs, array& out);
1519};
1520
// Body of a primitive whose `class` header (listing line 1521) and
// constructor-name line (listing line 1523) are both missing from this
// extract.
// NOTE(review): presumably this is QuantizedMatmul deriving from
// UnaryPrimitive (the initializer list delegates to UnaryPrimitive(stream));
// the name is not visible here -- confirm against the upstream header.
1522 public:
// Constructor parameter list (the line naming the constructor is missing).
// Each argument is stored in the member of the same name.
1524 Stream stream,
1525 int group_size,
1526 int bits,
1527 bool transpose)
1528 : UnaryPrimitive(stream),
1529 group_size_(group_size),
1530 bits_(bits),
1531 transpose_(transpose) {}
1532
// Backend evaluation overrides: take the input arrays and one output array.
1533 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1534 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1535
// NOTE(review): listing lines 1536-1538 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT);
// confirm against the upstream header.
// Equivalence check and output shapes are defined out of line.
1539 bool is_equivalent(const Primitive& other) const override;
1540 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1541
1542 private:
// Configuration captured at construction time.
1543 int group_size_;
1544 int bits_;
1545 bool transpose_;
1546
// Private evaluation helper; its definition is not visible in this header.
1547 void eval(const std::vector<array>& inputs, array& out);
1548};
1549
// Body of the GatherQMM primitive declaration. The `class` header line
// (listing line 1550) is missing from this extract.
// NOTE(review): the base is presumably UnaryPrimitive, since the constructor
// delegates to UnaryPrimitive(stream) -- confirm against the upstream header.
1551 public:
// Stores group_size, bits and transpose in the members of the same name.
1552 explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
1553 : UnaryPrimitive(stream),
1554 group_size_(group_size),
1555 bits_(bits),
1556 transpose_(transpose) {}
1557
// Backend evaluation overrides: take the input arrays and one output array.
1558 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1559 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1560
// NOTE(review): listing lines 1561-1563 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT);
// confirm against the upstream header.
// Equivalence check is defined out of line.
1564 bool is_equivalent(const Primitive& other) const override;
1565
1566 private:
// Configuration captured at construction time.
1567 int group_size_;
1568 int bits_;
1569 bool transpose_;
1570
// Private evaluation helper; its definition is not visible in this header.
1571 void eval(const std::vector<array>& inputs, array& out);
1572};
1573
// Body of the RandomBits primitive declaration. The `class` header line
// (listing line 1574) is missing from this extract.
// NOTE(review): the base is presumably UnaryPrimitive, since the constructor
// delegates to UnaryPrimitive(stream) -- confirm against the upstream header.
1575 public:
// Copies `shape` and stores `width` in the matching members.
// NOTE(review): the unit of `width` (bits vs. bytes) is not stated in this
// header -- confirm in the implementation.
1576 explicit RandomBits(Stream stream, const Shape& shape, int width)
1577 : UnaryPrimitive(stream), shape_(shape), width_(width) {}
1578
// Backend evaluation overrides: take the input arrays and one output array.
1579 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1580 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1581
// NOTE(review): listing lines 1582-1583 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_PRINT); confirm.
// Equivalence check is defined out of line.
1584 bool is_equivalent(const Primitive& other) const override;
1585
1586 private:
// Parameters captured at construction time.
1587 Shape shape_;
1588 int width_;
1589
// Private evaluation helper; its definition is not visible in this header.
1590 void eval(const std::vector<array>& inputs, array& out);
1591};
1592
// Real primitive (a UnaryPrimitive subclass): stream-only constructor plus
// CPU/GPU evaluation overrides. No private section is visible for this
// class in the extract.
1593class Real : public UnaryPrimitive {
1594 public:
// Forwards the stream to the UnaryPrimitive constructor.
1595 explicit Real(Stream stream) : UnaryPrimitive(stream) {}
1596
// Backend evaluation overrides: take the input arrays and one output array.
1597 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1598 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1599
// NOTE(review): listing lines 1600-1604 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT /
// ...); confirm against the upstream header.
1605};
1606
// Reshape primitive (a UnaryPrimitive subclass) parameterized by a target
// shape, with two private static helpers used by its implementation.
1607class Reshape : public UnaryPrimitive {
1608 public:
// Copies `shape` into shape_.
1609 explicit Reshape(Stream stream, const Shape& shape)
1610 : UnaryPrimitive(stream), shape_(shape) {}
1611
// Backend evaluation overrides: take the input arrays and one output array.
1612 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1613 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1614
// NOTE(review): listing lines 1615-1617 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT);
// confirm against the upstream header.
// Equivalence check is defined out of line.
1618 bool is_equivalent(const Primitive& other) const override;
1619
1620 private:
// Shape captured at construction time.
1621 Shape shape_;
1622
// Private evaluation helper; its definition is not visible in this header.
1623 void eval(const std::vector<array>& inputs, array& out);
1624
// Static helper returning a (bool, Strides) pair computed from `in`/`out`.
// NOTE(review): presumably the bool says whether a copy is required and the
// Strides are the output strides to use otherwise -- confirm in the .cpp.
1625 static std::pair<bool, Strides> prepare_reshape(
1626 const array& in,
1627 const array& out);
// Static helper that fills `out` from `in` using the given output strides.
// NOTE(review): the name suggests `out` shares `in`'s buffer -- confirm.
1628 static void shared_buffer_reshape(
1629 const array& in,
1630 const Strides& out_strides,
1631 array& out);
1632};
1633
// Reduce primitive (a UnaryPrimitive subclass) parameterized by a reduction
// kind and a list of axes.
1634class Reduce : public UnaryPrimitive {
1635 public:
// Supported reduction kinds (unscoped enum nested in Reduce).
1636 enum ReduceType { And, Or, Sum, Prod, Min, Max };
1637
// Stores the reduction kind and copies `axes` into axes_.
1638 explicit Reduce(
1639 Stream stream,
1640 ReduceType reduce_type,
1641 const std::vector<int>& axes)
1642 : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1643
// Backend evaluation overrides: take the input arrays and one output array.
1644 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1645 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1646
// NOTE(review): listing line 1647 is absent from this extract -- presumably
// a macro invocation (e.g. DEFINE_VMAP); confirm against the upstream header.
1648
// Vector-Jacobian product override; only vjp is declared explicitly here
// (no jvp declaration is visible in this extract).
1649 std::vector<array> vjp(
1650 const std::vector<array>& primals,
1651 const std::vector<array>& cotangents,
1652 const std::vector<int>& argnums,
1653 const std::vector<array>& outputs) override;
1654
// Output shapes are computed out of line.
1655 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1656
// Writes the enumerator name of reduce_type_ ("And", "Or", "Sum", "Prod",
// "Min" or "Max") to `os`. There is no default case, so a value outside the
// enum writes nothing.
1657 void print(std::ostream& os) override {
1658 switch (reduce_type_) {
1659 case And:
1660 os << "And";
1661 break;
1662 case Or:
1663 os << "Or";
1664 break;
1665 case Sum:
1666 os << "Sum";
1667 break;
1668 case Prod:
1669 os << "Prod";
1670 break;
1671 case Min:
1672 os << "Min";
1673 break;
1674 case Max:
1675 os << "Max";
1676 break;
1677 }
1678 }
// Equivalence check is defined out of line.
1679 bool is_equivalent(const Primitive& other) const override;
1680
1681 private:
// Parameters captured at construction time.
1682 ReduceType reduce_type_;
1683 std::vector<int> axes_;
1684
// Private evaluation helper; its definition is not visible in this header.
1685 void eval(const std::vector<array>& inputs, array& out);
1686};
1687
// Round primitive (a UnaryPrimitive subclass): stream-only constructor plus
// CPU/GPU evaluation overrides.
1688class Round : public UnaryPrimitive {
1689 public:
// Forwards the stream to the UnaryPrimitive constructor.
1690 explicit Round(Stream stream) : UnaryPrimitive(stream) {}
1691
// Backend evaluation overrides: take the input arrays and one output array.
1692 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1693 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1694
// NOTE(review): listing lines 1695-1699 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT /
// ...); confirm against the upstream header.
1700
1701 private:
// Private evaluation helper; its definition is not visible in this header.
1702 void eval(const std::vector<array>& inputs, array& out);
1703};
1704
// Scan primitive (a UnaryPrimitive subclass) parameterized by a reduction
// kind, an axis and reverse/inclusive flags.
1705class Scan : public UnaryPrimitive {
1706 public:
// NOTE(review): listing line 1707 is absent from this extract -- presumably
// the nested `enum ReduceType` declaration; print() below switches over
// Sum, Prod, Min and Max. Confirm against the upstream header.
1708
// Stores every argument in the member of the same name.
1709 explicit Scan(
1710 Stream stream,
1711 ReduceType reduce_type,
1712 int axis,
1713 bool reverse,
1714 bool inclusive)
1715 : UnaryPrimitive(stream),
1716 reduce_type_(reduce_type),
1717 axis_(axis),
1718 reverse_(reverse),
1719 inclusive_(inclusive) {}
1720
// Backend evaluation overrides: take the input arrays and one output array.
1721 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1722 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1723
// NOTE(review): listing lines 1724-1725 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS); confirm.
1726
// Writes "Cum" followed by the reduction name, e.g. "CumSum" or "CumMax".
// There is no default case, so any other value prints just "Cum".
1727 void print(std::ostream& os) override {
1728 os << "Cum";
1729 switch (reduce_type_) {
1730 case Sum:
1731 os << "Sum";
1732 break;
1733 case Prod:
1734 os << "Prod";
1735 break;
1736 case Min:
1737 os << "Min";
1738 break;
1739 case Max:
1740 os << "Max";
1741 break;
1742 }
1743 }
// Equivalence check is defined out of line.
1744 bool is_equivalent(const Primitive& other) const override;
1745
1746 private:
// Parameters captured at construction time.
1747 ReduceType reduce_type_;
1748 int axis_;
1749 bool reverse_;
1750 bool inclusive_;
1751
// Private evaluation helper; its definition is not visible in this header.
1752 void eval(const std::vector<array>& inputs, array& out);
1753};
1754
// Scatter primitive (a UnaryPrimitive subclass) parameterized by a reduction
// kind and a list of axes.
1755class Scatter : public UnaryPrimitive {
1756 public:
// NOTE(review): listing line 1757 is absent from this extract -- presumably
// the nested `enum ReduceType` declaration; print() below switches over
// Sum, Prod, Min, Max and None. Confirm against the upstream header.
1758
// Stores the reduction kind and copies `axes` into axes_.
1759 explicit Scatter(
1760 Stream stream,
1761 ReduceType reduce_type,
1762 const std::vector<int>& axes)
1763 : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1764
// Backend evaluation overrides: take the input arrays and one output array.
1765 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1766 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1767
// NOTE(review): listing lines 1768-1769 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS); confirm.
1770
// Writes "Scatter", followed by a space and the reduction name for
// Sum/Prod/Min/Max (e.g. "Scatter Sum"). For None nothing is appended.
1771 void print(std::ostream& os) override {
1772 os << "Scatter";
1773 switch (reduce_type_) {
1774 case Sum:
1775 os << " Sum";
1776 break;
1777 case Prod:
1778 os << " Prod";
1779 break;
1780 case Min:
1781 os << " Min";
1782 break;
1783 case Max:
1784 os << " Max";
1785 break;
1786 case None:
1787 break;
1788 }
1789 }
// Equivalence check is defined out of line.
1790 bool is_equivalent(const Primitive& other) const override;
1791
1792 private:
// Private evaluation helper; its definition is not visible in this header.
1793 void eval(const std::vector<array>& inputs, array& out);
// Parameters captured at construction time.
1794 ReduceType reduce_type_;
1795 std::vector<int> axes_;
1796};
1797
// Sigmoid primitive (a UnaryPrimitive subclass): stream-only constructor plus
// CPU/GPU evaluation overrides.
1798class Sigmoid : public UnaryPrimitive {
1799 public:
// Forwards the stream to the UnaryPrimitive constructor.
1800 explicit Sigmoid(Stream stream) : UnaryPrimitive(stream) {}
1801
// Backend evaluation overrides: take the input arrays and one output array.
1802 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1803 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1804
// NOTE(review): listing lines 1805-1809 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT /
// ...); confirm against the upstream header.
1810
1811 private:
// Private evaluation helper; its definition is not visible in this header.
1812 void eval(const std::vector<array>& inputs, array& out);
1813};
1814
// Sign primitive (a UnaryPrimitive subclass): stream-only constructor plus
// CPU/GPU evaluation overrides.
1815class Sign : public UnaryPrimitive {
1816 public:
// Forwards the stream to the UnaryPrimitive constructor.
1817 explicit Sign(Stream stream) : UnaryPrimitive(stream) {}
1818
// Backend evaluation overrides: take the input arrays and one output array.
1819 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1820 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1821
// NOTE(review): listing lines 1822-1826 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT /
// ...); confirm against the upstream header.
1827
1828 private:
// Private evaluation helper; its definition is not visible in this header.
1829 void eval(const std::vector<array>& inputs, array& out);
1830};
1831
// Sin primitive (a UnaryPrimitive subclass): stream-only constructor plus
// CPU/GPU evaluation overrides.
1832class Sin : public UnaryPrimitive {
1833 public:
// Forwards the stream to the UnaryPrimitive constructor.
1834 explicit Sin(Stream stream) : UnaryPrimitive(stream) {}
1835
// Backend evaluation overrides: take the input arrays and one output array.
1836 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1837 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1838
// NOTE(review): listing lines 1839-1843 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT /
// ...); confirm against the upstream header.
1844
1845 private:
// Private evaluation helper; its definition is not visible in this header.
1846 void eval(const std::vector<array>& inputs, array& out);
1847};
1848
// Sinh primitive (a UnaryPrimitive subclass): stream-only constructor plus
// CPU/GPU evaluation overrides.
1849class Sinh : public UnaryPrimitive {
1850 public:
// Forwards the stream to the UnaryPrimitive constructor.
1851 explicit Sinh(Stream stream) : UnaryPrimitive(stream) {}
1852
// Backend evaluation overrides: take the input arrays and one output array.
1853 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1854 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1855
// NOTE(review): listing lines 1856-1860 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT /
// ...); confirm against the upstream header.
1861
1862 private:
// Private evaluation helper; its definition is not visible in this header.
1863 void eval(const std::vector<array>& inputs, array& out);
1864};
1865
// Slice primitive (a UnaryPrimitive subclass) parameterized by start indices,
// end indices and strides.
1866class Slice : public UnaryPrimitive {
1867 public:
// Copies the three vectors into the corresponding members.
// NOTE(review): presumably one entry per input dimension in each vector --
// not enforced in this header; confirm at callers.
1868 explicit Slice(
1869 Stream stream,
1870 const std::vector<int>& start_indices,
1871 const std::vector<int>& end_indices,
1872 const std::vector<int>& strides)
1873 : UnaryPrimitive(stream),
1874 start_indices_(start_indices),
1875 end_indices_(end_indices),
1876 strides_(strides) {}
1877
// Backend evaluation overrides: take the input arrays and one output array.
1878 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1879 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1880
// NOTE(review): listing lines 1881-1883 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT);
// confirm against the upstream header.
// Equivalence check is defined out of line.
1884 bool is_equivalent(const Primitive& other) const override;
1885
1886 private:
// Slice parameters captured at construction time.
1887 std::vector<int> start_indices_;
1888 std::vector<int> end_indices_;
1889 std::vector<int> strides_;
1890
// Private evaluation helper; its definition is not visible in this header.
1891 void eval(const std::vector<array>& inputs, array& out);
1892};
1893
// Body of the SliceUpdate primitive declaration. The `class` header line
// (listing line 1894) is missing from this extract.
// NOTE(review): the base is presumably UnaryPrimitive, since the constructor
// delegates to UnaryPrimitive(stream) -- confirm against the upstream header.
1895 public:
// Copies the three vectors into the corresponding members.
1896 explicit SliceUpdate(
1897 Stream stream,
1898 const std::vector<int>& start_indices,
1899 const std::vector<int>& end_indices,
1900 const std::vector<int>& strides)
1901 : UnaryPrimitive(stream),
1902 start_indices_(start_indices),
1903 end_indices_(end_indices),
1904 strides_(strides) {}
1905
// Backend evaluation overrides: take the input arrays and one output array.
1906 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1907 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1908
// NOTE(review): listing lines 1909-1911 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT);
// confirm against the upstream header.
// Equivalence check is defined out of line.
1912 bool is_equivalent(const Primitive& other) const override;
1913
1914 private:
// Slice parameters captured at construction time.
1915 std::vector<int> start_indices_;
1916 std::vector<int> end_indices_;
1917 std::vector<int> strides_;
1918
// Private evaluation helper; its definition is not visible in this header.
1919 void eval(const std::vector<array>& inputs, array& out);
1920
// Non-static helper returning an (int64_t, vector<int64_t>) tuple for `in`.
// NOTE(review): presumably a data offset plus the strides of the sliced
// view -- confirm in the implementation.
1921 std::tuple<int64_t, std::vector<int64_t>> prepare_slice(const array& in);
1922};
1923
// Softmax primitive (a UnaryPrimitive subclass) carrying a `precise` flag.
1924class Softmax : public UnaryPrimitive {
1925 public:
// Stores `precise` in precise_.
// NOTE(review): the effect of `precise` is not shown in this header --
// presumably it selects a higher-precision accumulation path; confirm.
1926 explicit Softmax(Stream stream, bool precise)
1927 : UnaryPrimitive(stream), precise_(precise) {}
1928
// Backend evaluation overrides: take the input arrays and one output array.
1929 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1930 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1931
// NOTE(review): listing lines 1932-1935 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT /
// ...); confirm against the upstream header.
1936
// Equivalence check is defined out of line.
1937 bool is_equivalent(const Primitive& other) const override;
1938
1939 private:
// Private evaluation helper; its definition is not visible in this header.
1940 void eval(const std::vector<array>& inputs, array& out);
// Flag captured at construction time.
1941 bool precise_;
1942};
1943
// Sort primitive (a UnaryPrimitive subclass) parameterized by an axis.
1944class Sort : public UnaryPrimitive {
1945 public:
// Stores `axis` in axis_.
1946 explicit Sort(Stream stream, int axis)
1947 : UnaryPrimitive(stream), axis_(axis) {}
1948
// Backend evaluation overrides: take the input arrays and one output array.
1949 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1950 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1951
// NOTE(review): listing lines 1952-1955 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT /
// ...); confirm against the upstream header.
// Equivalence check is defined out of line.
1956 bool is_equivalent(const Primitive& other) const override;
1957
1958 private:
// Axis captured at construction time.
1959 int axis_;
1960
// Private evaluation helper; its definition is not visible in this header.
1961 void eval(const std::vector<array>& inputs, array& out);
1962};
1963
// Split primitive. Unlike most classes in this listing it derives from
// Primitive directly, and its evaluation hooks take a vector of outputs.
1964class Split : public Primitive {
1965 public:
// Copies `indices` and stores `axis` in the matching members.
1966 explicit Split(Stream stream, const std::vector<int>& indices, int axis)
1967 : Primitive(stream), indices_(indices), axis_(axis) {}
1968
// Backend evaluation overrides: take the input arrays and the output arrays.
1969 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
1970 override;
1971 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
1972 override;
1973
// NOTE(review): listing lines 1974-1976 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT);
// confirm against the upstream header.
// Equivalence check is defined out of line.
1977 bool is_equivalent(const Primitive& other) const override;
1978
1979 private:
// Private evaluation helper; its definition is not visible in this header.
1980 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
1981
// Parameters captured at construction time.
1982 std::vector<int> indices_;
1983 int axis_;
1984};
1985
// Square primitive (a UnaryPrimitive subclass): stream-only constructor plus
// CPU/GPU evaluation overrides.
1986class Square : public UnaryPrimitive {
1987 public:
// Forwards the stream to the UnaryPrimitive constructor.
1988 explicit Square(Stream stream) : UnaryPrimitive(stream) {}
1989
// Backend evaluation overrides: take the input arrays and one output array.
1990 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1991 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1992
// NOTE(review): listing lines 1993-1997 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT /
// ...); confirm against the upstream header.
1998
1999 private:
// Private evaluation helper; its definition is not visible in this header.
2000 void eval(const std::vector<array>& inputs, array& out);
2001};
2002
// Sqrt primitive (a UnaryPrimitive subclass). A single class covers two
// printed names, "Sqrt" and "Rsqrt", selected by the `recip` flag.
2003class Sqrt : public UnaryPrimitive {
2004 public:
// `recip` defaults to false and is stored in recip_.
2005 explicit Sqrt(Stream stream, bool recip = false)
2006 : UnaryPrimitive(stream), recip_(recip) {}
2007
// Backend evaluation overrides: take the input arrays and one output array.
2008 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2009 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2010
// NOTE(review): listing lines 2011-2013 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / ...); confirm.
// Equivalence check is defined out of line.
2014 bool is_equivalent(const Primitive& other) const override;
2015
// Writes "Rsqrt" when recip_ is set, otherwise "Sqrt".
2016 void print(std::ostream& os) override {
2017 if (recip_) {
2018 os << "Rsqrt";
2019 } else {
2020 os << "Sqrt";
2021 }
2022 }
2023
2024 private:
// Private evaluation helper; its definition is not visible in this header.
2025 void eval(const std::vector<array>& inputs, array& out);
// Flag captured at construction time; selects the printed name above.
2026 bool recip_;
2027};
2028
// Body of the StopGradient primitive declaration. The `class` header line
// (listing line 2029) is missing from this extract.
// NOTE(review): the base is presumably UnaryPrimitive, since the constructor
// delegates to UnaryPrimitive(stream) -- confirm against the upstream header.
2030 public:
// Forwards the stream to the UnaryPrimitive constructor.
2031 explicit StopGradient(Stream stream) : UnaryPrimitive(stream) {}
2032
// Backend evaluation overrides: take the input arrays and one output array.
2033 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2034 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2035
// NOTE(review): listing lines 2036-2039 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_PRINT / ...); confirm
// against the upstream header.
2040
2041 private:
// Private evaluation helper; its definition is not visible in this header.
2042 void eval(const std::vector<array>& inputs, array& out);
2043};
2044
// Subtract primitive (a UnaryPrimitive subclass): stream-only constructor
// plus CPU/GPU evaluation overrides.
2045class Subtract : public UnaryPrimitive {
2046 public:
// Forwards the stream to the UnaryPrimitive constructor.
2047 explicit Subtract(Stream stream) : UnaryPrimitive(stream) {}
2048
// Backend evaluation overrides: take the input arrays and one output array.
2049 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2050 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2051
// NOTE(review): listing lines 2052-2056 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT /
// ...); confirm against the upstream header.
2057
2058 private:
// Private evaluation helper; its definition is not visible in this header.
2059 void eval(const std::vector<array>& inputs, array& out);
2060};
2061
// Tan primitive (a UnaryPrimitive subclass): stream-only constructor plus
// CPU/GPU evaluation overrides.
2062class Tan : public UnaryPrimitive {
2063 public:
// Forwards the stream to the UnaryPrimitive constructor.
2064 explicit Tan(Stream stream) : UnaryPrimitive(stream) {}
2065
// Backend evaluation overrides: take the input arrays and one output array.
2066 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2067 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2068
// NOTE(review): listing lines 2069-2073 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT /
// ...); confirm against the upstream header.
2074
2075 private:
// Private evaluation helper; its definition is not visible in this header.
2076 void eval(const std::vector<array>& inputs, array& out);
2077};
2078
// Tanh primitive (a UnaryPrimitive subclass): stream-only constructor plus
// CPU/GPU evaluation overrides.
2079class Tanh : public UnaryPrimitive {
2080 public:
// Forwards the stream to the UnaryPrimitive constructor.
2081 explicit Tanh(Stream stream) : UnaryPrimitive(stream) {}
2082
// Backend evaluation overrides: take the input arrays and one output array.
2083 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2084 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2085
// NOTE(review): listing lines 2086-2090 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT /
// ...); confirm against the upstream header.
2091
2092 private:
// Private evaluation helper; its definition is not visible in this header.
2093 void eval(const std::vector<array>& inputs, array& out);
2094};
2095
// Uniform primitive (a UnaryPrimitive subclass): stream-only constructor plus
// CPU/GPU evaluation overrides.
2096class Uniform : public UnaryPrimitive {
2097 public:
// Forwards the stream to the UnaryPrimitive constructor.
2098 explicit Uniform(Stream stream) : UnaryPrimitive(stream) {}
2099
// Backend evaluation overrides: take the input arrays and one output array.
2100 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2101 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2102
// NOTE(review): listing lines 2103-2105 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_PRINT / ...); confirm
// against the upstream header.
2106
2107 private:
// Private evaluation helper; its definition is not visible in this header.
2108 void eval(const std::vector<array>& inputs, array& out);
2109};
2110
// View primitive (a UnaryPrimitive subclass) parameterized by a target dtype.
// Unlike most classes in this listing it declares print() out of line and
// has no private eval helper.
2111class View : public UnaryPrimitive {
2112 public:
// Stores `dtype` in dtype_.
2113 explicit View(Stream stream, Dtype dtype)
2114 : UnaryPrimitive(stream), dtype_(dtype) {}
2115
// Backend evaluation overrides: take the input arrays and one output array.
2116 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2117 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2118
// NOTE(review): listing line 2119 is absent from this extract -- presumably
// a macro invocation (e.g. DEFINE_VMAP); confirm against the upstream header.
// Printing and the equivalence check are both defined out of line.
2120 void print(std::ostream& os) override;
2121 bool is_equivalent(const Primitive& other) const override;
2122
2123 private:
// Dtype captured at construction time.
2124 Dtype dtype_;
2125};
2126
// Body of the Transpose primitive declaration. The `class` header line
// (listing line 2127) is missing from this extract.
// NOTE(review): the base is presumably UnaryPrimitive, since the constructor
// delegates to UnaryPrimitive(stream) -- confirm against the upstream header.
2128 public:
// Copies `axes` into axes_.
2129 explicit Transpose(Stream stream, const std::vector<int>& axes)
2130 : UnaryPrimitive(stream), axes_(axes) {}
2131
// Backend evaluation overrides: take the input arrays and one output array.
2132 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2133 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2134
// NOTE(review): listing lines 2135-2137 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_GRADS / DEFINE_PRINT);
// confirm against the upstream header.
// Equivalence check and output shapes are defined out of line.
2138 bool is_equivalent(const Primitive& other) const override;
2139 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
2140
2141 private:
// Axis list captured at construction time.
2142 std::vector<int> axes_;
2143
// Private evaluation helper; its definition is not visible in this header.
2144 void eval(const std::vector<array>& inputs, array& out);
2145};
2146
2147/* QR Factorization primitive. */
// Derives from Primitive directly; its evaluation hooks take a vector of
// outputs rather than a single output array.
2148class QRF : public Primitive {
2149 public:
// Forwards the stream to the Primitive constructor.
2150 explicit QRF(Stream stream) : Primitive(stream) {}
2151
// Backend evaluation overrides: take the input arrays and the output arrays.
2152 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2153 override;
2154 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2155 override;
2156
// NOTE(review): listing line 2157 is absent from this extract -- presumably
// a macro invocation (e.g. DEFINE_PRINT); confirm against the upstream header.
2158
2159 private:
// Private evaluation helper; its definition is not visible in this header.
2160 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
2161};
2162
2163/* SVD primitive. */
// Derives from Primitive directly; its evaluation hooks take a vector of
// outputs rather than a single output array.
2164class SVD : public Primitive {
2165 public:
// Forwards the stream to the Primitive constructor.
2166 explicit SVD(Stream stream) : Primitive(stream) {}
2167
// Backend evaluation overrides: take the input arrays and the output arrays.
2168 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2169 override;
2170 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2171 override;
2172
// NOTE(review): listing lines 2173-2174 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_PRINT); confirm against
// the upstream header.
2175
2176 private:
// Private evaluation helper; its definition is not visible in this header.
2177 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
2178};
2179
2180/* Matrix inversion primitive. */
// A UnaryPrimitive subclass carrying two boolean flags, `tri` and `upper`.
2181class Inverse : public UnaryPrimitive {
2182 public:
// Stores both flags in the members of the same name.
// NOTE(review): presumably `tri` marks a triangular input and `upper`
// selects which triangle -- not stated in this header; confirm.
2183 explicit Inverse(Stream stream, bool tri, bool upper)
2184 : UnaryPrimitive(stream), tri_(tri), upper_(upper) {}
2185
// Backend evaluation overrides: take the input arrays and one output array
// (the parameter is named `output` here, not `out` as elsewhere).
2186 void eval_cpu(const std::vector<array>& inputs, array& output) override;
2187 void eval_gpu(const std::vector<array>& inputs, array& output) override;
2188
// NOTE(review): listing lines 2189-2190 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_PRINT); confirm against
// the upstream header.
2191
2192 private:
// Private evaluation helper; its definition is not visible in this header.
2193 void eval(const std::vector<array>& inputs, array& output);
// Flags captured at construction time.
2194 bool tri_;
2195 bool upper_;
2196};
2197
// Cholesky primitive (a UnaryPrimitive subclass) carrying an `upper` flag.
2198class Cholesky : public UnaryPrimitive {
2199 public:
// Stores `upper` in upper_.
// NOTE(review): presumably selects the upper vs. lower triangular factor --
// not stated in this header; confirm in the implementation.
2200 explicit Cholesky(Stream stream, bool upper)
2201 : UnaryPrimitive(stream), upper_(upper) {}
2202
// Backend evaluation overrides: take the input arrays and one output array.
2203 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2204 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2205
// NOTE(review): listing lines 2206-2207 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_PRINT); confirm against
// the upstream header.
2208
2209 private:
// Private evaluation helper; its definition is not visible in this header.
2210 void eval(const std::vector<array>& inputs, array& output);
// Flag captured at construction time.
2211 bool upper_;
2212};
2213
// Eigh primitive. Derives from Primitive directly; its evaluation hooks take
// a vector of outputs. Configured by a `uplo` string and a flag that controls
// whether eigenvectors are computed.
2214class Eigh : public Primitive {
2215 public:
// `uplo` is taken by value and moved into uplo_; the flag is copied.
// NOTE(review): accepted `uplo` values are not validated in this header --
// presumably "L"/"U"; confirm against the callers.
2216 explicit Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
2217 : Primitive(stream),
2218 uplo_(std::move(uplo)),
2219 compute_eigenvectors_(compute_eigenvectors) {}
2220
// Backend evaluation overrides: take the input arrays and the output arrays.
2221 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2222 override;
2223 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2224 override;
2225
// NOTE(review): listing lines 2226-2227 are absent from this extract --
// presumably macro invocations (DEFINE_VMAP / DEFINE_PRINT); confirm against
// the upstream header.
2228
// Output shapes are computed out of line.
2229 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
2230
// Equivalence check is defined out of line.
2231 bool is_equivalent(const Primitive& other) const override;
2232
2233 private:
// Private evaluation helper; its definition is not visible in this header.
2234 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
// Configuration captured at construction time.
2235 std::string uplo_;
2236 bool compute_eigenvectors_;
2237};
2238
2239} // namespace mlx::core
Definition primitives.h:154
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Abs(Stream stream)
Definition primitives.h:156
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:163
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:164
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:165
Definition primitives.h:171
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Add(Stream stream)
Definition primitives.h:173
Definition primitives.h:188
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
AddMM(Stream stream, float alpha, float beta)
Definition primitives.h:190
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition primitives.h:212
Arange(Stream stream, double start, double stop, double step)
Definition primitives.h:214
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:232
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCos(Stream stream)
Definition primitives.h:234
Definition primitives.h:249
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCosh(Stream stream)
Definition primitives.h:251
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:266
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcSin(Stream stream)
Definition primitives.h:268
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:283
ArcSinh(Stream stream)
Definition primitives.h:285
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:317
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTan2(Stream stream)
Definition primitives.h:319
Definition primitives.h:300
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcTan(Stream stream)
Definition primitives.h:302
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:334
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTanh(Stream stream)
Definition primitives.h:336
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:351
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArgPartition(Stream stream, int kth, int axis)
Definition primitives.h:353
Definition primitives.h:372
ReduceType
Definition primitives.h:374
@ ArgMin
Definition primitives.h:375
@ ArgMax
Definition primitives.h:376
ArgReduce(Stream stream, ReduceType reduce_type, int axis)
Definition primitives.h:379
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:398
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArgSort(Stream stream, int axis)
Definition primitives.h:400
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:437
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
AsStrided(Stream stream, Shape shape, Strides strides, size_t offset)
Definition primitives.h:439
Definition primitives.h:417
void eval_gpu(const std::vector< array > &inputs, array &out) override
AsType(Stream stream, Dtype dtype)
Definition primitives.h:419
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:460
BitwiseBinary(Stream stream, Op op)
Definition primitives.h:464
void eval_cpu(const std::vector< array > &inputs, array &out) override
Op
Definition primitives.h:462
@ And
Definition primitives.h:462
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:480
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
BlockMaskedMM(Stream stream, int block_size)
Definition primitives.h:482
Definition primitives.h:523
Broadcast(Stream stream, const Shape &shape)
Definition primitives.h:525
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:542
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Ceil(Stream stream)
Definition primitives.h:544
Definition primitives.h:2198
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cholesky(Stream stream, bool upper)
Definition primitives.h:2200
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:559
Compiled(Stream stream, std::vector< array > inputs, std::vector< array > outputs, std::vector< array > tape, std::unordered_set< uintptr_t > constant_ids)
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition primitives.h:601
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Concatenate(Stream stream, int axis)
Definition primitives.h:603
Definition primitives.h:621
Conjugate(Stream stream)
Definition primitives.h:623
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:637
Contiguous(Stream stream, bool allow_col_major)
Definition primitives.h:639
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:656
void eval_gpu(const std::vector< array > &inputs, array &out) override
Convolution(Stream stream, const std::vector< int > &kernel_strides, const std::vector< int > &padding, const std::vector< int > &kernel_dilation, const std::vector< int > &input_dilation, const int groups=1, const bool flip=false)
Definition primitives.h:658
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition primitives.h:697
void eval_gpu(const std::vector< array > &inputs, array &out) override
Copy(Stream stream)
Definition primitives.h:699
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:714
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cos(Stream stream)
Definition primitives.h:716
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:731
void eval_gpu(const std::vector< array > &inputs, array &out) override
Cosh(Stream stream)
Definition primitives.h:733
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:748
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
CustomTransforms(Stream stream, int num_outputs, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> vjp, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< int > &)> jvp, std::function< std::pair< std::vector< array >, std::vector< int > >(const std::vector< array > &, const std::vector< int > &)> vmap)
Definition primitives.h:750
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition primitives.h:800
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Depends(Stream stream)
Definition primitives.h:802
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:838
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
DivMod(Stream stream)
Definition primitives.h:840
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition primitives.h:821
Divide(Stream stream)
Definition primitives.h:823
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2214
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
Definition primitives.h:2216
Definition primitives.h:893
Equal(Stream stream, bool equal_nan=false)
Definition primitives.h:895
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:919
Erf(Stream stream)
Definition primitives.h:921
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:936
void eval_gpu(const std::vector< array > &inputs, array &out) override
ErfInv(Stream stream)
Definition primitives.h:938
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:953
Exp(Stream stream)
Definition primitives.h:955
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:970
Expm1(Stream stream)
Definition primitives.h:972
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:986
FFT(Stream stream, const std::vector< size_t > &axes, bool inverse, bool real)
Definition primitives.h:988
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1012
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Floor(Stream stream)
Definition primitives.h:1014
Definition primitives.h:1029
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Full(Stream stream)
Definition primitives.h:1031
Definition primitives.h:1045
Gather(Stream stream, const std::vector< int > &axes, const std::vector< int > &slice_sizes)
Definition primitives.h:1047
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:503
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
GatherMM(Stream stream)
Definition primitives.h:505
Definition primitives.h:1550
GatherQMM(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1552
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1085
void eval_cpu(const std::vector< array > &inputs, array &out) override
GreaterEqual(Stream stream)
Definition primitives.h:1087
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1068
Greater(Stream stream)
Definition primitives.h:1070
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1102
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Hadamard(Stream stream, float scale)
Definition primitives.h:1104
Definition primitives.h:1123
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Imag(Stream stream)
Definition primitives.h:1125
Definition primitives.h:2181
void eval_gpu(const std::vector< array > &inputs, array &output) override
Inverse(Stream stream, bool tri, bool upper)
Definition primitives.h:2183
void eval_cpu(const std::vector< array > &inputs, array &output) override
Definition primitives.h:1154
LessEqual(Stream stream)
Definition primitives.h:1156
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1137
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Less(Stream stream)
Definition primitives.h:1139
Definition primitives.h:1171
void eval_gpu(const std::vector< array > &inputs, array &out) override
Load(Stream stream, std::shared_ptr< io::Reader > reader, size_t offset, bool swap_endianness=false)
Definition primitives.h:1173
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1237
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Log1p(Stream stream)
Definition primitives.h:1239
Definition primitives.h:1304
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogAddExp(Stream stream)
Definition primitives.h:1306
Definition primitives.h:1203
Base
Definition primitives.h:1205
Log(Stream stream, Base base)
Definition primitives.h:1207
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1270
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalAnd(Stream stream)
Definition primitives.h:1272
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1253
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalNot(Stream stream)
Definition primitives.h:1255
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1287
void eval_cpu(const std::vector< array > &inputs, array &out) override
LogicalOr(Stream stream)
Definition primitives.h:1289
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1321
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
Matmul(Stream stream)
Definition primitives.h:1323
Definition primitives.h:1340
Maximum(Stream stream)
Definition primitives.h:1342
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1357
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Minimum(Stream stream)
Definition primitives.h:1359
Definition primitives.h:1374
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Multiply(Stream stream)
Definition primitives.h:1376
Definition primitives.h:1391
void eval_gpu(const std::vector< array > &inputs, array &out) override
Negative(Stream stream)
Definition primitives.h:1393
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1408
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
NotEqual(Stream stream)
Definition primitives.h:1410
Definition primitives.h:1425
void eval_gpu(const std::vector< array > &inputs, array &out) override
NumberOfElements(Stream stream, std::vector< int > axes, bool inverted, Dtype dtype)
Definition primitives.h:1427
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1455
void eval_cpu(const std::vector< array > &inputs, array &out) override
Pad(Stream stream, const std::vector< int > &axes, const std::vector< int > &low_pad_size, const std::vector< int > &high_pad_size)
Definition primitives.h:1457
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1483
void eval_cpu(const std::vector< array > &inputs, array &out) override
Partition(Stream stream, int kth, int axis)
Definition primitives.h:1485
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1504
void eval_cpu(const std::vector< array > &inputs, array &out) override
Power(Stream stream)
Definition primitives.h:1506
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:48
virtual void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
The vector-Jacobian product.
virtual ~Primitive()=default
Primitive(const Primitive &other)=delete
Primitive(Primitive &&other)=delete
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
Primitive & operator=(Primitive &&other)=delete
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
Primitive & operator=(const Primitive &other)=delete
const Device & device()
The device the primitive will run on.
Definition primitives.h:53
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
The Jacobian-vector product.
virtual std::vector< Shape > output_shapes(const std::vector< array > &inputs)
Get the output shapes of the primitive.
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes)
The primitive must know how to vectorize itself across the given axes.
virtual void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
virtual void print(std::ostream &os)=0
Print the primitive.
Primitive(Stream stream)
Definition primitives.h:50
Definition primitives.h:2148
QRF(Stream stream)
Definition primitives.h:2150
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:1521
void eval_gpu(const std::vector< array > &inputs, array &out) override
QuantizedMatmul(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1523
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1574
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
RandomBits(Stream stream, const Shape &shape, int width)
Definition primitives.h:1576
Definition primitives.h:1593
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Real(Stream stream)
Definition primitives.h:1595
Definition primitives.h:1634
Reduce(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1638
ReduceType
Definition primitives.h:1636
@ And
Definition primitives.h:1636
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:876
Remainder(Stream stream)
Definition primitives.h:878
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1607
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Reshape(Stream stream, const Shape &shape)
Definition primitives.h:1609
Definition primitives.h:1688
Round(Stream stream)
Definition primitives.h:1690
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2164
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
SVD(Stream stream)
Definition primitives.h:2166
Definition primitives.h:1705
void eval_cpu(const std::vector< array > &inputs, array &out) override
ReduceType
Definition primitives.h:1707
@ Max
Definition primitives.h:1707
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Scan(Stream stream, ReduceType reduce_type, int axis, bool reverse, bool inclusive)
Definition primitives.h:1709
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1755
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
ReduceType
Definition primitives.h:1757
@ Max
Definition primitives.h:1757
void eval_cpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1771
void eval_gpu(const std::vector< array > &inputs, array &out) override
Scatter(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1759
Definition primitives.h:859
void eval_gpu(const std::vector< array > &inputs, array &out) override
Select(Stream stream)
Definition primitives.h:861
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1798
Sigmoid(Stream stream)
Definition primitives.h:1800
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1815
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Sign(Stream stream)
Definition primitives.h:1817
Definition primitives.h:1832
Sin(Stream stream)
Definition primitives.h:1834
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1849
Sinh(Stream stream)
Definition primitives.h:1851
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1866
void eval_cpu(const std::vector< array > &inputs, array &out) override
Slice(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
Definition primitives.h:1868
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1894
SliceUpdate(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
Definition primitives.h:1896
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1924
void eval_gpu(const std::vector< array > &inputs, array &out) override
Softmax(Stream stream, bool precise)
Definition primitives.h:1926
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1944
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sort(Stream stream, int axis)
Definition primitives.h:1946
Definition primitives.h:1964
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Split(Stream stream, const std::vector< int > &indices, int axis)
Definition primitives.h:1966
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition primitives.h:2003
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sqrt(Stream stream, bool recip=false)
Definition primitives.h:2005
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1986
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Square(Stream stream)
Definition primitives.h:1988
Definition primitives.h:2029
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
StopGradient(Stream stream)
Definition primitives.h:2031
Definition primitives.h:2045
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Subtract(Stream stream)
Definition primitives.h:2047
Definition primitives.h:2062
Tan(Stream stream)
Definition primitives.h:2064
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2079
void eval_gpu(const std::vector< array > &inputs, array &out) override
Tanh(Stream stream)
Definition primitives.h:2081
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2127
Transpose(Stream stream, const std::vector< int > &axes)
Definition primitives.h:2129
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:126
UnaryPrimitive & operator=(const UnaryPrimitive &other)=delete
UnaryPrimitive(Stream stream)
An abstract base class for a primitive with a single output.
Definition primitives.h:131
virtual void eval_gpu(const std::vector< array > &inputs, array &output)=0
UnaryPrimitive(UnaryPrimitive &&other)=delete
virtual void eval_cpu(const std::vector< array > &inputs, array &output)=0
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:141
UnaryPrimitive(const UnaryPrimitive &other)=delete
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition primitives.h:136
UnaryPrimitive & operator=(UnaryPrimitive &&other)=delete
virtual ~UnaryPrimitive()=default
Definition primitives.h:2096
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Uniform(Stream stream)
Definition primitives.h:2098
Definition primitives.h:2111
void eval_cpu(const std::vector< array > &inputs, array &out) override
View(Stream stream, Dtype dtype)
Definition primitives.h:2113
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition array.h:23
Op op
Definition binary.h:129
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
array tri(int n, int m, int k, Dtype type, StreamOrDevice s={})
array transpose(const array &a, std::vector< int > axes, StreamOrDevice s={})
Permutes the dimensions according to the given axes.
array real(const array &a, StreamOrDevice s={})
Definition allocator.h:7
std::pair< std::vector< array >, std::vector< array > > jvp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &tangents)
Computes the output and Jacobian-vector product (JVP) of a function.
std::pair< std::vector< array >, std::vector< array > > vjp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &cotangents)
Computes the output and vector-Jacobian product (VJP) of a function.
void eval(std::vector< array > outputs)
std::vector< int32_t > Shape
Definition array.h:20
std::function< array(const array &)> vmap(const std::function< array(const array &)> &fun, int in_axis=0, int out_axis=0)
Automatically vectorize a unary function over the requested axes.
std::vector< size_t > Strides
Definition array.h:21
#define DEFINE_DEFAULT_IS_EQUIVALENT()
Definition primitives.h:34
#define DEFINE_PRINT(PRIMITIVE)
Definition primitives.h:29
#define DEFINE_INPUT_OUTPUT_SHAPE()
Definition primitives.h:39
#define DEFINE_GRADS()
Definition primitives.h:17
#define DEFINE_VMAP()
Definition primitives.h:12
Definition ops.h:37
Definition binary_ops.h:270
Definition ops.h:185
Definition ops.h:163
Definition ops.h:29
Definition ops.h:78
Definition ops.h:141
Definition binary_ops.h:277
Definition ops.h:119
Definition device.h:7
Definition dtype.h:13
Definition stream.h:9
Device device
Definition stream.h:11