MLX
 
Loading...
Searching...
No Matches
primitives.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <unordered_set>
6
7#include "mlx/array.h"
8#include "mlx/device.h"
9#include "mlx/io/load.h"
10#include "mlx/stream.h"
11
// Expands to the declaration of a virtual `vmap` override. The declared
// function takes the input arrays plus one axis per input and returns a pair
// of (arrays, axes). Meant to be used inside a class body; the definition
// must be provided elsewhere.
12#define DEFINE_VMAP() \
13 virtual std::pair<std::vector<array>, std::vector<int>> vmap( \
14 const std::vector<array>& inputs, const std::vector<int>& axes) \
15 override;
16
// Expands to the declarations of the `jvp` and `vjp` overrides (same
// signatures as the virtuals declared on Primitive). Only declarations are
// produced; the definitions must be provided elsewhere.
17#define DEFINE_GRADS() \
18 std::vector<array> jvp( \
19 const std::vector<array>& primals, \
20 const std::vector<array>& tangents, \
21 const std::vector<int>& argnums) override; \
22 \
23 std::vector<array> vjp( \
24 const std::vector<array>& primals, \
25 const std::vector<array>& cotangents, \
26 const std::vector<int>& argnums, \
27 const std::vector<array>& outputs) override;
28
// Expands to an inline `print` override that writes the stringized macro
// argument (`#PRIMITIVE`) to the given stream, e.g. DEFINE_PRINT(Abs) prints
// the text "Abs".
29#define DEFINE_PRINT(PRIMITIVE) \
30 void print(std::ostream& os) override { \
31 os << #PRIMITIVE; \
32 }
33
// Expands to an inline `is_equivalent` override that unconditionally returns
// true; the `other` argument is not inspected.
// NOTE(review): presumably intended for primitives that carry no state and
// are only compared after their types are known to match — confirm at the
// call site.
34#define DEFINE_DEFAULT_IS_EQUIVALENT() \
35 bool is_equivalent(const Primitive& other) const override { \
36 return true; \
37 }
38
// Expands to an inline `output_shapes` override that returns a one-element
// vector holding the shape of the first input.
// Note: `inputs[0]` is read without checking that `inputs` is non-empty.
39#define DEFINE_INPUT_OUTPUT_SHAPE() \
40 std::vector<Shape> output_shapes(const std::vector<array>& inputs) \
41 override { \
42 return {inputs[0].shape()}; \
43 }
44
45namespace mlx::core {
46
47// Abstract base class
// Base class for all primitives declared in this header. It is abstract:
// eval_cpu, eval_gpu and print are pure virtual. Each instance stores the
// Stream passed at construction and is neither copyable nor movable.
48class Primitive {
49 public:
// Stores `stream` by value in stream_.
50 explicit Primitive(Stream stream) : stream_(stream) {}
51
// Returns the `device` member of the stored stream.
// NOTE(review): this accessor and stream() only read state but are not
// const-qualified, so they cannot be called through a `const Primitive&`.
53 const Device& device() {
54 return stream().device;
55 }
56
// Returns a reference to the stored stream.
58 const Stream& stream() {
59 return stream_;
60 }
61
// Evaluation entry points, one per backend. `inputs` is read-only and
// `outputs` is passed by mutable reference. Must be implemented by subclasses.
69 virtual void eval_cpu(
70 const std::vector<array>& inputs,
71 std::vector<array>& outputs) = 0;
72 virtual void eval_gpu(
73 const std::vector<array>& inputs,
74 std::vector<array>& outputs) = 0;
75
// Overridable transformation hook; only declared here, so the default
// behavior is defined outside this header.
// NOTE(review): presumably the Jacobian-vector product — confirm.
79 virtual std::vector<array> jvp(
80 const std::vector<array>& primals,
81 const std::vector<array>& tangents,
82 const std::vector<int>& argnums);
83
// Overridable transformation hook that additionally receives the primitive's
// `outputs`; only declared here.
// NOTE(review): presumably the vector-Jacobian product — confirm.
87 virtual std::vector<array> vjp(
88 const std::vector<array>& primals,
89 const std::vector<array>& cotangents,
90 const std::vector<int>& argnums,
91 const std::vector<array>& outputs);
92
// Overridable hook taking the inputs and one axis per input and returning a
// pair of (arrays, axes); only declared here.
99 virtual std::pair<std::vector<array>, std::vector<int>> vmap(
100 const std::vector<array>& inputs,
101 const std::vector<int>& axes);
102
// Writes a textual representation of the primitive to `os`.
104 virtual void print(std::ostream& os) = 0;
105
// Default equivalence check: returns false unless a subclass overrides it.
107 virtual bool is_equivalent(const Primitive& other) const {
108 return false;
109 }
110
// Returns the output shapes for the given inputs; the default implementation
// is defined outside this header.
113 virtual std::vector<Shape> output_shapes(const std::vector<array>& inputs);
114
// Virtual destructor for deletion through a base pointer; copy and move
// operations are all deleted.
115 virtual ~Primitive() = default;
116 Primitive(const Primitive& other) = delete;
117 Primitive(Primitive&& other) = delete;
118 Primitive& operator=(const Primitive& other) = delete;
119 Primitive& operator=(Primitive&& other) = delete;
120
121 private:
122 // Every primitive stores the stream it should run in
123 Stream stream_;
124};
125
// Convenience base for primitives that write a single output array.
// Subclasses implement the `array& output` overloads; the vector-based
// overrides below forward to them using the first element of `outputs`.
// NOTE(review): the embedded listing numbers skip 127-129, 131, 149 and 151,
// so the constructor (and possibly move operations) are not visible here.
126class UnaryPrimitive : public Primitive {
130 public:
132
// Single-output evaluation entry points that subclasses must implement.
133 virtual void eval_cpu(const std::vector<array>& inputs, array& output) = 0;
134 virtual void eval_gpu(const std::vector<array>& inputs, array& output) = 0;
135
// Adapters from the multi-output Primitive interface.
// Note: `outputs[0]` is accessed without checking that `outputs` is non-empty.
136 inline void eval_cpu(
137 const std::vector<array>& inputs,
138 std::vector<array>& outputs) override {
139 eval_cpu(inputs, outputs[0]);
140 }
141 inline void eval_gpu(
142 const std::vector<array>& inputs,
143 std::vector<array>& outputs) override {
144 eval_gpu(inputs, outputs[0]);
145 }
146
// Copy construction and copy assignment are deleted.
147 virtual ~UnaryPrimitive() = default;
148 UnaryPrimitive(const UnaryPrimitive& other) = delete;
150 UnaryPrimitive& operator=(const UnaryPrimitive& other) = delete;
152};
153
// Abs: single-output primitive (derives from UnaryPrimitive) with CPU and GPU
// evaluation overrides.
// NOTE(review): presumably element-wise absolute value — confirm in the
// implementation. Listing numbers 156 and 161-165 are missing here, so the
// constructor and other members are not shown.
154class Abs : public UnaryPrimitive {
155 public:
157
158 void eval_cpu(const std::vector<array>& inputs, array& out) override;
159 void eval_gpu(const std::vector<array>& inputs, array& out) override;
160
166
167 private:
// Private helper with the same signature as the eval overrides; presumably a
// shared implementation — TODO confirm.
168 void eval(const std::vector<array>& inputs, array& out);
169};
170
// Add: single-output primitive (derives from UnaryPrimitive) with CPU and GPU
// evaluation overrides.
// NOTE(review): presumably element-wise addition of two inputs — confirm in
// the implementation. Listing numbers 173 and 178-182 are missing here, so
// the constructor and other members are not shown.
171class Add : public UnaryPrimitive {
172 public:
174
175 void eval_cpu(const std::vector<array>& inputs, array& out) override;
176 void eval_gpu(const std::vector<array>& inputs, array& out) override;
177
183
184 private:
// Private helper with the same signature as the eval overrides; presumably a
// shared implementation — TODO confirm.
185 void eval(const std::vector<array>& inputs, array& out);
186};
187
// AddMM: single-output primitive parameterized by two float scalars, alpha
// and beta, fixed at construction (both members are const).
// NOTE(review): presumably computes alpha * (a @ b) + beta * c — confirm in
// the implementation.
188class AddMM : public UnaryPrimitive {
189 public:
// Forwards `stream` to UnaryPrimitive and stores the two scalars.
190 explicit AddMM(Stream stream, float alpha, float beta)
191 : UnaryPrimitive(stream), alpha_(alpha), beta_(beta) {}
192
193 void eval_cpu(const std::vector<array>& inputs, array& out) override;
194 void eval_gpu(const std::vector<array>& inputs, array& out) override;
195
// Custom vjp override; declared here, defined elsewhere.
196 std::vector<array> vjp(
197 const std::vector<array>& primals,
198 const std::vector<array>& cotangents,
199 const std::vector<int>& argnums,
200 const std::vector<array>& outputs) override;
201
204
205 bool is_equivalent(const Primitive& other) const override;
// Returns the stored parameters as (alpha, beta).
206 std::pair<float, float> state() const {
207 return {alpha_, beta_};
208 };
209
210 private:
211 const float alpha_;
212 const float beta_;
213};
214
// Arange: single-output primitive parameterized by start, stop and step
// (all doubles).
// NOTE(review): presumably generates evenly spaced values in [start, stop) —
// confirm in the implementation.
215class Arange : public UnaryPrimitive {
216 public:
// Forwards `stream` to UnaryPrimitive and stores the three range parameters.
217 explicit Arange(Stream stream, double start, double stop, double step)
218 : UnaryPrimitive(stream), start_(start), stop_(stop), step_(step) {}
219
220 void eval_cpu(const std::vector<array>& inputs, array& out) override;
221 void eval_gpu(const std::vector<array>& inputs, array& out) override;
222
224 bool is_equivalent(const Primitive& other) const override;
225 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
// Returns the stored parameters as (start, stop, step).
226 std::tuple<double, double, double> state() const {
227 return {start_, stop_, step_};
228 };
229
230 private:
231 double start_;
232 double stop_;
233 double step_;
234
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
235 void eval(const std::vector<array>& inputs, array& out);
236};
237
// ArcCos: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably element-wise inverse cosine — confirm. The
// constructor and other members (listing numbers 240, 245-249) are not shown.
238class ArcCos : public UnaryPrimitive {
239 public:
241
242 void eval_cpu(const std::vector<array>& inputs, array& out) override;
243 void eval_gpu(const std::vector<array>& inputs, array& out) override;
244
250
251 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
252 void eval(const std::vector<array>& inputs, array& out);
253};
254
// ArcCosh: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably element-wise inverse hyperbolic cosine — confirm.
// The constructor and other members (listing numbers 257, 262-266) are not
// shown.
255class ArcCosh : public UnaryPrimitive {
256 public:
258
259 void eval_cpu(const std::vector<array>& inputs, array& out) override;
260 void eval_gpu(const std::vector<array>& inputs, array& out) override;
261
267
268 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
269 void eval(const std::vector<array>& inputs, array& out);
270};
271
// ArcSin: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably element-wise inverse sine — confirm. The
// constructor and other members (listing numbers 274, 279-283) are not shown.
272class ArcSin : public UnaryPrimitive {
273 public:
275
276 void eval_cpu(const std::vector<array>& inputs, array& out) override;
277 void eval_gpu(const std::vector<array>& inputs, array& out) override;
278
284
285 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
286 void eval(const std::vector<array>& inputs, array& out);
287};
288
// ArcSinh: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably element-wise inverse hyperbolic sine — confirm.
// The constructor and other members (listing numbers 291, 296-300) are not
// shown.
289class ArcSinh : public UnaryPrimitive {
290 public:
292
293 void eval_cpu(const std::vector<array>& inputs, array& out) override;
294 void eval_gpu(const std::vector<array>& inputs, array& out) override;
295
301
302 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
303 void eval(const std::vector<array>& inputs, array& out);
304};
305
// ArcTan: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably element-wise inverse tangent — confirm. The
// constructor and other members (listing numbers 308, 313-317) are not shown.
306class ArcTan : public UnaryPrimitive {
307 public:
309
310 void eval_cpu(const std::vector<array>& inputs, array& out) override;
311 void eval_gpu(const std::vector<array>& inputs, array& out) override;
312
318
319 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
320 void eval(const std::vector<array>& inputs, array& out);
321};
322
// ArcTan2: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably the two-argument inverse tangent applied
// element-wise — confirm. The constructor and other members (listing numbers
// 325, 330-334) are not shown.
323class ArcTan2 : public UnaryPrimitive {
324 public:
326
327 void eval_cpu(const std::vector<array>& inputs, array& out) override;
328 void eval_gpu(const std::vector<array>& inputs, array& out) override;
329
335
336 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
337 void eval(const std::vector<array>& inputs, array& out);
338};
339
// ArcTanh: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably element-wise inverse hyperbolic tangent — confirm.
// The constructor and other members (listing numbers 342, 347-351) are not
// shown.
340class ArcTanh : public UnaryPrimitive {
341 public:
343
344 void eval_cpu(const std::vector<array>& inputs, array& out) override;
345 void eval_gpu(const std::vector<array>& inputs, array& out) override;
346
352
353 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
354 void eval(const std::vector<array>& inputs, array& out);
355};
356
// NOTE(review): the class header line (listing number 357) is missing from
// this listing. The constructor below names the class ArgPartition and shows
// that it derives from UnaryPrimitive.
// Single-output primitive parameterized by `kth` and `axis`.
// NOTE(review): presumably returns indices that partition the input around
// the kth element along `axis` — confirm in the implementation.
358 public:
// Forwards `stream` to UnaryPrimitive and stores kth and axis.
359 explicit ArgPartition(Stream stream, int kth, int axis)
360 : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
361
362 void eval_cpu(const std::vector<array>& inputs, array& out) override;
363 void eval_gpu(const std::vector<array>& inputs, array& out) override;
364
369 bool is_equivalent(const Primitive& other) const override;
// Returns the stored parameters as (kth, axis).
370 std::pair<int, int> state() const {
371 return {kth_, axis_};
372 };
373
374 private:
375 int kth_;
376 int axis_;
377
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
378 void eval(const std::vector<array>& inputs, array& out);
379};
380
// ArgReduce: single-output primitive parameterized by a ReduceType and an
// axis.
// NOTE(review): the ReduceType declaration (listing numbers 383-386) is
// missing from this listing; presumably an enum selecting arg-min vs arg-max
// — confirm.
381class ArgReduce : public UnaryPrimitive {
382 public:
387
// Forwards `stream` to UnaryPrimitive and stores the reduction kind and axis.
388 explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)
389 : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {}
390
391 void eval_cpu(const std::vector<array>& inputs, array& out) override;
392 void eval_gpu(const std::vector<array>& inputs, array& out) override;
393
397 bool is_equivalent(const Primitive& other) const override;
398 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
// Returns the stored parameters as (reduce_type, axis).
399 std::pair<ReduceType, int> state() const {
400 return {reduce_type_, axis_};
401 };
402
403 private:
404 ReduceType reduce_type_;
405 int axis_;
406
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
407 void eval(const std::vector<array>& inputs, array& out);
408};
409
// ArgSort: single-output primitive parameterized by an axis.
// NOTE(review): presumably returns the indices that sort the input along
// `axis` — confirm in the implementation.
410class ArgSort : public UnaryPrimitive {
411 public:
// Forwards `stream` to UnaryPrimitive and stores the axis.
412 explicit ArgSort(Stream stream, int axis)
413 : UnaryPrimitive(stream), axis_(axis) {}
414
415 void eval_cpu(const std::vector<array>& inputs, array& out) override;
416 void eval_gpu(const std::vector<array>& inputs, array& out) override;
417
421 bool is_equivalent(const Primitive& other) const override;
// Returns the stored axis.
422 int state() const {
423 return axis_;
424 };
425
426 private:
427 int axis_;
428
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
429 void eval(const std::vector<array>& inputs, array& out);
430};
431
// AsType: single-output primitive parameterized by a target Dtype.
// NOTE(review): presumably casts the input to `dtype` — confirm in the
// implementation.
432class AsType : public UnaryPrimitive {
433 public:
// Forwards `stream` to UnaryPrimitive and stores the dtype.
434 explicit AsType(Stream stream, Dtype dtype)
435 : UnaryPrimitive(stream), dtype_(dtype) {}
436
437 void eval_cpu(const std::vector<array>& inputs, array& out) override;
438 void eval_gpu(const std::vector<array>& inputs, array& out) override;
439
444 bool is_equivalent(const Primitive& other) const override;
// Returns the stored dtype.
445 Dtype state() const {
446 return dtype_;
447 };
448
449 private:
450 Dtype dtype_;
451
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
452 void eval(const std::vector<array>& inputs, array& out);
453};
454
// AsStrided: single-output primitive parameterized by a shape, strides and
// an offset.
// NOTE(review): presumably produces a view of the input with the given
// shape/strides starting at `offset` — confirm in the implementation.
// NOTE(review): listing number 458 (the start of the member-initializer
// list, presumably the UnaryPrimitive base initializer) is missing below.
455class AsStrided : public UnaryPrimitive {
456 public:
// Takes shape and strides by value and moves them into the members.
457 explicit AsStrided(Stream stream, Shape shape, Strides strides, size_t offset)
459 shape_(std::move(shape)),
460 strides_(std::move(strides)),
461 offset_(offset) {}
462
463 void eval_cpu(const std::vector<array>& inputs, array& out) override;
464 void eval_gpu(const std::vector<array>& inputs, array& out) override;
465
468 bool is_equivalent(const Primitive& other) const override;
// Returns copies of the stored parameters as a tuple (shape, strides, offset).
469 auto state() const {
470 return std::make_tuple(shape_, strides_, offset_);
471 }
472
473 private:
474 Shape shape_;
475 Strides strides_;
476 size_t offset_;
477
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
478 void eval(const std::vector<array>& inputs, array& out);
479};
480
// NOTE(review): the class header (listing number 481), the declaration of
// the `Op` type (483) and the first line of the constructor (485) are missing
// from this listing, so the class name cannot be determined from here.
// What is visible: the class derives from UnaryPrimitive, stores a single
// `Op` value, and provides its own print() instead of a fixed name.
482 public:
484
486 : UnaryPrimitive(stream), op_(op) {}
487
488 void eval_cpu(const std::vector<array>& inputs, array& out) override;
489 void eval_gpu(const std::vector<array>& inputs, array& out) override;
490
493 bool is_equivalent(const Primitive& other) const override;
// Declared here, defined elsewhere; presumably prints a name that depends on
// op_ — TODO confirm.
494 void print(std::ostream& os) override;
// Returns the stored operation selector.
496 auto state() const {
497 return op_;
498 }
499
500 private:
501 Op op_;
502};
503
// NOTE(review): the class header line (listing number 504) is missing from
// this listing. The constructor below names the class BlockMaskedMM and shows
// that it derives from UnaryPrimitive.
// Single-output primitive parameterized by a block size.
// NOTE(review): presumably a matrix multiplication with block-wise masks —
// confirm in the implementation.
505 public:
// Forwards `stream` to UnaryPrimitive and stores the block size.
506 explicit BlockMaskedMM(Stream stream, int block_size)
507 : UnaryPrimitive(stream), block_size_(block_size) {}
508
509 void eval_cpu(const std::vector<array>& inputs, array& out) override;
510 void eval_gpu(const std::vector<array>& inputs, array& out) override;
511
// Custom vjp override; declared here, defined elsewhere.
512 std::vector<array> vjp(
513 const std::vector<array>& primals,
514 const std::vector<array>& cotangents,
515 const std::vector<int>& argnums,
516 const std::vector<array>& outputs) override;
517
519 bool is_equivalent(const Primitive& other) const override;
// Returns the stored block size.
520 auto state() const {
521 return block_size_;
522 }
523
524 private:
525 int block_size_;
526
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
527 void eval(const std::vector<array>& inputs, array& out);
528};
529
// GatherMM: single-output primitive with CPU/GPU evaluation overrides and a
// custom vjp.
// NOTE(review): presumably a matrix multiplication over gathered operands —
// confirm. The constructor and other members (listing numbers 532, 543-544)
// are not shown.
530class GatherMM : public UnaryPrimitive {
531 public:
533
534 void eval_cpu(const std::vector<array>& inputs, array& out) override;
535 void eval_gpu(const std::vector<array>& inputs, array& out) override;
536
// Custom vjp override; declared here, defined elsewhere.
537 std::vector<array> vjp(
538 const std::vector<array>& primals,
539 const std::vector<array>& cotangents,
540 const std::vector<int>& argnums,
541 const std::vector<array>& outputs) override;
542
545
546 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
547 void eval(const std::vector<array>& inputs, array& out);
548};
549
// NOTE(review): the class header line (listing number 550) is missing from
// this listing. The constructor below names the class BroadcastAxes and shows
// that it derives from UnaryPrimitive.
// Single-output primitive parameterized by a list of axes to ignore.
// NOTE(review): presumably broadcasts the inputs against each other on all
// axes except `ignore_axes` — confirm in the implementation.
551 public:
// `ignore_axes` defaults to empty; it is taken by value and moved into the
// member.
552 explicit BroadcastAxes(Stream stream, std::vector<int> ignore_axes = {})
553 : UnaryPrimitive(stream), ignore_axes_(std::move(ignore_axes)) {}
554
555 void eval_cpu(const std::vector<array>& inputs, array& out) override;
556 void eval_gpu(const std::vector<array>& inputs, array& out) override;
557
561 bool is_equivalent(const Primitive& other) const override;
// NOTE(review): the first line of the next declaration (listing number 562)
// is missing; the two lines below are its parameter list. Presumably a static
// output_shape(inputs, ignore_axes) helper — confirm.
563 const std::vector<array>& inputs,
564 const std::vector<int>& ignore_axes);
565 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
// Returns a copy of the stored axes.
566 auto state() const {
567 return ignore_axes_;
568 }
569
570 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
571 void eval(const std::vector<array>& inputs, array& out);
572 std::vector<int> ignore_axes_;
573};
574
// Broadcast: single-output primitive parameterized by a target shape.
// NOTE(review): presumably broadcasts the input to `shape` — confirm in the
// implementation.
575class Broadcast : public UnaryPrimitive {
576 public:
// Forwards `stream` to UnaryPrimitive and copies `shape` into the member.
577 explicit Broadcast(Stream stream, const Shape& shape)
578 : UnaryPrimitive(stream), shape_(shape) {}
579
580 void eval_cpu(const std::vector<array>& inputs, array& out) override;
581 void eval_gpu(const std::vector<array>& inputs, array& out) override;
582
// Static helper computing a shape from a list of inputs; defined elsewhere.
586 static Shape output_shape(const std::vector<array>& inputs);
587 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
588 bool is_equivalent(const Primitive& other) const override;
// Returns a copy of the stored shape. The declared return type is
// std::vector<int> while the member is a Shape.
// NOTE(review): this relies on Shape being (convertible to) std::vector<int>
// — confirm against the Shape definition.
589 std::vector<int> state() const {
590 return shape_;
591 };
592
593 private:
594 Shape shape_;
595
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
596 void eval(const std::vector<array>& inputs, array& out);
597};
598
// Ceil: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably element-wise ceiling — confirm. The constructor
// and other members (listing numbers 601, 606-610) are not shown.
599class Ceil : public UnaryPrimitive {
600 public:
602
603 void eval_cpu(const std::vector<array>& inputs, array& out) override;
604 void eval_gpu(const std::vector<array>& inputs, array& out) override;
605
611
612 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
613 void eval(const std::vector<array>& inputs, array& out);
614};
615
// Compiled: multi-output primitive (derives directly from Primitive) that
// holds the inputs, outputs and tape of a traced computation plus a set of
// ids for inputs that may be treated as constants. See the block comment
// before the constructor for the invariants expected of these arguments.
616class Compiled : public Primitive {
617 public:
618 /*
619 * The inputs, outputs and tape are either tracers or constants.
620 * - The tape should not contain the inputs, but it should contain the
621 * outputs.
622 * - The tape should also have only one array per primitive for multi-output
623 * primitives.
624 * - The constant_ids contains ids of arrays in the input list that are safe
625 * to treat as scalar constants.
626 */
// NOTE(review): listing number 628 (presumably the `Stream stream` parameter)
// is missing below. The constructor is only declared here.
627 explicit Compiled(
629 std::vector<array> inputs,
630 std::vector<array> outputs,
631 std::vector<array> tape,
632 std::unordered_set<uintptr_t> constant_ids);
633
634 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
635 override;
636 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
637 override;
638
641 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
642 void print(std::ostream& os) override;
643 bool is_equivalent(const Primitive& other) const override;
644
// Returns a copy of the kernel library name string.
645 std::string lib_name() const {
646 return kernel_lib_;
647 }
648
649 private:
// Captured graph description; all four members are const after construction.
650 const std::vector<array> inputs_;
651 const std::vector<array> outputs_;
652 const std::vector<array> tape_;
653 const std::unordered_set<uintptr_t> constant_ids_;
654
// Not const, unlike the members above; where it is assigned is not visible in
// this header.
655 std::string kernel_lib_;
656};
657
// NOTE(review): the class header line (listing number 658) is missing from
// this listing. The constructor below names the class Concatenate and shows
// that it derives from UnaryPrimitive.
// Single-output primitive parameterized by an axis.
// NOTE(review): presumably concatenates the inputs along `axis` — confirm in
// the implementation.
659 public:
// Forwards `stream` to UnaryPrimitive and stores the axis.
660 explicit Concatenate(Stream stream, int axis)
661 : UnaryPrimitive(stream), axis_(axis) {}
662
663 void eval_cpu(const std::vector<array>& inputs, array& out) override;
664 void eval_gpu(const std::vector<array>& inputs, array& out) override;
665
669 bool is_equivalent(const Primitive& other) const override;
670 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
// Returns the stored axis.
671 auto state() const {
672 return axis_;
673 }
674
675 private:
676 int axis_;
677
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
678 void eval(const std::vector<array>& inputs, array& out);
679};
680
// Conjugate: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably element-wise complex conjugate — confirm. The
// constructor and other members (listing numbers 683, 688-691) are not shown.
681class Conjugate : public UnaryPrimitive {
682 public:
684
685 void eval_cpu(const std::vector<array>& inputs, array& out) override;
686 void eval_gpu(const std::vector<array>& inputs, array& out) override;
687
692
693 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
694 void eval(const std::vector<array>& inputs, array& out);
695};
696
// NOTE(review): the class header line (listing number 697) is missing from
// this listing. The constructor below names the class Contiguous and shows
// that it derives from UnaryPrimitive.
// Single-output primitive parameterized by a boolean flag.
// NOTE(review): presumably produces a contiguous copy of the input, with the
// flag permitting column-major layouts — confirm in the implementation.
698 public:
// Forwards `stream` to UnaryPrimitive and stores the flag.
699 explicit Contiguous(Stream stream, bool allow_col_major)
700 : UnaryPrimitive(stream), allow_col_major_(allow_col_major) {}
701
702 void eval_cpu(const std::vector<array>& inputs, array& out) override;
703 void eval_gpu(const std::vector<array>& inputs, array& out) override;
704
709
710 bool is_equivalent(const Primitive& other) const override;
711
712 private:
713 bool allow_col_major_;
714};
715
// NOTE(review): the class header line (listing number 716) is missing from
// this listing. The constructor below names the class Convolution.
// Single-output primitive holding convolution hyper-parameters: padding,
// kernel strides, kernel dilation, input dilation, group count and a flip
// flag.
// NOTE(review): the exact meaning of each parameter (e.g. per-spatial-dim
// vectors, what `flip` reverses) is not visible here — confirm in the
// implementation.
717 public:
// NOTE(review): listing numbers 719 (presumably the `Stream stream`
// parameter) and 726 (presumably the UnaryPrimitive base initializer) are
// missing below. `groups` defaults to 1 and `flip` to false.
718 explicit Convolution(
720 const std::vector<int>& kernel_strides,
721 const std::vector<int>& padding,
722 const std::vector<int>& kernel_dilation,
723 const std::vector<int>& input_dilation,
724 const int groups = 1,
725 const bool flip = false)
727 padding_(padding),
728 kernel_strides_(kernel_strides),
729 kernel_dilation_(kernel_dilation),
730 input_dilation_(input_dilation),
731 groups_(groups),
732 flip_(flip) {}
733
734 void eval_cpu(const std::vector<array>& inputs, array& out) override;
735 void eval_gpu(const std::vector<array>& inputs, array& out) override;
736
// Custom vjp override; declared here, defined elsewhere.
737 std::vector<array> vjp(
738 const std::vector<array>& primals,
739 const std::vector<array>& cotangents,
740 const std::vector<int>& argnums,
741 const std::vector<array>& outputs) override;
742
744 bool is_equivalent(const Primitive& other) const override;
// Returns copies of the stored parameters. The tuple order (padding first,
// then kernel strides) follows member order, not constructor-parameter order.
745 auto state() const {
746 return std::make_tuple(
747 padding_,
748 kernel_strides_,
749 kernel_dilation_,
750 input_dilation_,
751 groups_,
752 flip_);
753 }
754
755 private:
756 std::vector<int> padding_;
757 std::vector<int> kernel_strides_;
758 std::vector<int> kernel_dilation_;
759 std::vector<int> input_dilation_;
760 int groups_;
761 bool flip_;
762
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
763 void eval(const std::vector<array>& inputs, array& out);
764};
765
// Copy: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably copies the input array — confirm. The constructor
// and other members (listing numbers 768, 773-777) are not shown.
766class Copy : public UnaryPrimitive {
767 public:
769
770 void eval_cpu(const std::vector<array>& inputs, array& out) override;
771 void eval_gpu(const std::vector<array>& inputs, array& out) override;
772
778
779 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
780 void eval(const std::vector<array>& inputs, array& out);
781};
782
// Cos: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably element-wise cosine — confirm. The constructor and
// other members (listing numbers 785, 790-794) are not shown.
783class Cos : public UnaryPrimitive {
784 public:
786
787 void eval_cpu(const std::vector<array>& inputs, array& out) override;
788 void eval_gpu(const std::vector<array>& inputs, array& out) override;
789
795
796 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
797 void eval(const std::vector<array>& inputs, array& out);
798};
799
// Cosh: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably element-wise hyperbolic cosine — confirm. The
// constructor and other members (listing numbers 802, 807-811) are not shown.
800class Cosh : public UnaryPrimitive {
801 public:
803
804 void eval_cpu(const std::vector<array>& inputs, array& out) override;
805 void eval_gpu(const std::vector<array>& inputs, array& out) override;
806
812
813 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
814 void eval(const std::vector<array>& inputs, array& out);
815};
816
// NOTE(review): the class header (listing number 817) and the first two
// constructor lines (819-820) are missing from this listing, so the class
// name is not visible here. The initializer list shows that it derives
// directly from Primitive (multi-output).
// What is visible: the class stores an output count and three user-supplied
// callables (vjp, jvp, vmap) that are moved into members at construction.
// NOTE(review): std::function is used but <functional> is not among the
// includes visible in this file; presumably it arrives transitively —
// confirm.
818 public:
821 int num_outputs,
822 std::function<std::vector<array>(
823 const std::vector<array>&,
824 const std::vector<array>&,
825 const std::vector<array>&)> vjp,
826 std::function<std::vector<array>(
827 const std::vector<array>&,
828 const std::vector<array>&,
829 const std::vector<int>&)> jvp,
830 std::function<std::pair<std::vector<array>, std::vector<int>>(
831 const std::vector<array>&,
832 const std::vector<int>&)> vmap)
833 : Primitive(stream),
834 num_outputs_(num_outputs),
835 vjp_fun_(std::move(vjp)),
836 jvp_fun_(std::move(jvp)),
837 vmap_fun_(std::move(vmap)) {}
838
839 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
840 override;
841 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
842 override;
843
847
848 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
849 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
850
851 int num_outputs_;
852
// Callable taking three array vectors and returning an array vector.
853 std::function<std::vector<array>(
854 const std::vector<array>&,
855 const std::vector<array>&,
856 const std::vector<array>&)>
857 vjp_fun_;
// Callable taking two array vectors plus an int vector and returning an array
// vector.
858 std::function<std::vector<array>(
859 const std::vector<array>&,
860 const std::vector<array>&,
861 const std::vector<int>&)>
862 jvp_fun_;
// Callable with the same shape as Primitive::vmap: (arrays, axes) in,
// (arrays, axes) out.
863 std::function<std::pair<std::vector<array>, std::vector<int>>(
864 const std::vector<array>&,
865 const std::vector<int>&)>
866 vmap_fun_;
867};
868
// Depends: multi-output primitive (derives directly from Primitive) with
// CPU/GPU evaluation overrides and a custom vjp.
// NOTE(review): presumably used to insert explicit dependencies between
// arrays — confirm. The constructor and other members (listing numbers 871,
// 884) are not shown.
869class Depends : public Primitive {
870 public:
872
873 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
874 override;
875 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
876 override;
877
// Custom vjp override; declared here, defined elsewhere.
878 std::vector<array> vjp(
879 const std::vector<array>& primals,
880 const std::vector<array>& cotan,
881 const std::vector<int>& argnums,
882 const std::vector<array>& outputs) override;
883
885
886 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
887 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
888};
889
// Divide: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably element-wise division of two inputs — confirm. The
// constructor and other members (listing numbers 892, 897-901) are not shown.
890class Divide : public UnaryPrimitive {
891 public:
893
894 void eval_cpu(const std::vector<array>& inputs, array& out) override;
895 void eval_gpu(const std::vector<array>& inputs, array& out) override;
896
902
903 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
904 void eval(const std::vector<array>& inputs, array& out);
905};
906
// DivMod: multi-output primitive (derives directly from Primitive).
// output_shapes() reports two outputs, both with the first input's shape.
// NOTE(review): presumably the outputs are quotient and remainder — confirm.
// The constructor and other members (listing numbers 909, 916-919) are not
// shown.
907class DivMod : public Primitive {
908 public:
910
911 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
912 override;
913 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
914 override;
915
// Returns two copies of inputs[0].shape(). `inputs[0]` is read without
// checking that `inputs` is non-empty.
920 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
921 return std::vector{inputs[0].shape(), inputs[0].shape()};
922 }
923
924 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
925 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
926};
927
// Select: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably an element-wise conditional choice between inputs
// — confirm. The constructor and other members (listing numbers 930, 935-939)
// are not shown.
928class Select : public UnaryPrimitive {
929 public:
931
932 void eval_cpu(const std::vector<array>& inputs, array& out) override;
933 void eval_gpu(const std::vector<array>& inputs, array& out) override;
934
940
941 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
942 void eval(const std::vector<array>& inputs, array& out);
943};
944
// Remainder: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably element-wise remainder of two inputs — confirm.
// The constructor and other members (listing numbers 947, 952-956) are not
// shown.
945class Remainder : public UnaryPrimitive {
946 public:
948
949 void eval_cpu(const std::vector<array>& inputs, array& out) override;
950 void eval_gpu(const std::vector<array>& inputs, array& out) override;
951
957
958 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
959 void eval(const std::vector<array>& inputs, array& out);
960};
961
// Equal: single-output primitive parameterized by an `equal_nan` flag
// (default false). The flag also selects the printed name.
// NOTE(review): presumably element-wise equality where the flag makes NaNs
// compare equal — confirm in the implementation.
962class Equal : public UnaryPrimitive {
963 public:
// Forwards `stream` to UnaryPrimitive and stores the flag.
964 explicit Equal(Stream stream, bool equal_nan = false)
965 : UnaryPrimitive(stream), equal_nan_(equal_nan) {}
966
967 void eval_cpu(const std::vector<array>& inputs, array& out) override;
968 void eval_gpu(const std::vector<array>& inputs, array& out) override;
969
974
// Writes "NaNEqual" when equal_nan_ is set, otherwise "Equal".
975 void print(std::ostream& os) override {
976 if (equal_nan_) {
977 os << "NaNEqual";
978 } else {
979 os << "Equal";
980 }
981 }
// Returns the stored flag.
982 auto state() const {
983 return equal_nan_;
984 };
985
986 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
987 void eval(const std::vector<array>& inputs, array& out);
988 bool equal_nan_;
989};
990
// Erf: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably the element-wise error function — confirm. The
// constructor and other members (listing numbers 993, 998-1002) are not
// shown.
991class Erf : public UnaryPrimitive {
992 public:
994
995 void eval_cpu(const std::vector<array>& inputs, array& out) override;
996 void eval_gpu(const std::vector<array>& inputs, array& out) override;
997
1003
1004 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
1005 void eval(const std::vector<array>& inputs, array& out);
1006};
1007
// ErfInv: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably the element-wise inverse error function — confirm.
// The constructor and other members (listing numbers 1010, 1015-1019) are not
// shown.
1008class ErfInv : public UnaryPrimitive {
1009 public:
1011
1012 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1013 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1014
1020
1021 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
1022 void eval(const std::vector<array>& inputs, array& out);
1023};
1024
// Exp: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably the element-wise exponential — confirm. The
// constructor and other members (listing numbers 1027, 1032-1036) are not
// shown.
1025class Exp : public UnaryPrimitive {
1026 public:
1028
1029 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1030 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1031
1037
1038 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
1039 void eval(const std::vector<array>& inputs, array& out);
1040};
1041
// Expm1: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably element-wise exp(x) - 1 — confirm. The constructor
// and other members (listing numbers 1044, 1049-1052) are not shown.
1042class Expm1 : public UnaryPrimitive {
1043 public:
1045
1046 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1047 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1048
1053
1054 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
1055 void eval(const std::vector<array>& inputs, array& out);
1056};
1057
// NOTE(review): the class header line (listing number 1058) is missing from
// this listing. The constructor below names the class ExpandDims and shows
// that it derives from UnaryPrimitive.
// Single-output primitive parameterized by a list of axes.
// NOTE(review): presumably inserts size-one dimensions at `axes` — confirm in
// the implementation.
1059 public:
// Takes `axes` by value and moves it into the member.
1060 explicit ExpandDims(Stream stream, std::vector<int> axes)
1061 : UnaryPrimitive(stream), axes_(std::move(axes)) {}
1062
1063 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1064 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1065
1069
1070 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1071 bool is_equivalent(const Primitive& other) const override;
1072
// Static helper computing the result shape from an input and the axes;
// defined elsewhere.
1073 static Shape output_shape(const array& input, const std::vector<int>& axes);
// Returns a copy of the stored axes.
1074 auto state() const {
1075 return axes_;
1076 }
1077
1078 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
1079 void eval(const std::vector<array>& inputs, array& out);
1080 std::vector<int> axes_;
1081};
1082
// FFT: single-output primitive parameterized by a list of axes and two
// boolean flags, `inverse` and `real`.
// NOTE(review): presumably a Fourier transform over `axes`, with the flags
// selecting inverse and real-valued variants — confirm in the implementation.
1083class FFT : public UnaryPrimitive {
1084 public:
// Copies `axes` into the member and stores both flags.
1085 explicit FFT(
1086 Stream stream,
1087 const std::vector<size_t>& axes,
1088 bool inverse,
1089 bool real)
1090 : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {}
1091
1092 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1093 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1094
1098
1099 bool is_equivalent(const Primitive& other) const override;
// Returns copies of the stored parameters as a tuple (axes, inverse, real).
1100 auto state() const {
1101 return std::make_tuple(axes_, inverse_, real_);
1102 }
1103
1104 private:
1105 std::vector<size_t> axes_;
1106 bool inverse_;
1107 bool real_;
1108
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
1109 void eval(const std::vector<array>& inputs, array& out);
1110};
1111
// Flatten: single-output primitive parameterized by a start and end axis.
// NOTE(review): presumably collapses the dimensions from start_axis through
// end_axis into one; whether end_axis is inclusive is not visible here —
// confirm in the implementation.
1112class Flatten : public UnaryPrimitive {
1113 public:
// Forwards `stream` to UnaryPrimitive and stores both axes.
1114 explicit Flatten(Stream stream, int start_axis, int end_axis)
1115 : UnaryPrimitive(stream), start_axis_(start_axis), end_axis_(end_axis) {}
1116
1117 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1118 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1119
1123 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1124 bool is_equivalent(const Primitive& other) const override;
1125
// Static helper computing the result shape from an input and the axis range;
// defined elsewhere.
1126 static Shape output_shape(const array& input, int start_axis, int end_axis);
// Returns the stored parameters as (start_axis, end_axis).
1127 auto state() const {
1128 return std::make_pair(start_axis_, end_axis_);
1129 }
1130
1131 private:
1132 int start_axis_;
1133 int end_axis_;
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
1134 void eval(const std::vector<array>& inputs, array& out);
1135};
1136
// Floor: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably element-wise floor — confirm. The constructor and
// other members (listing numbers 1139, 1144-1148) are not shown.
1137class Floor : public UnaryPrimitive {
1138 public:
1140
1141 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1142 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1143
1149
1150 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
1151 void eval(const std::vector<array>& inputs, array& out);
1152};
1153
// Full: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably fills the output with a given value — confirm. The
// constructor and other members (listing numbers 1156, 1161-1164) are not
// shown.
1154class Full : public UnaryPrimitive {
1155 public:
1157
1158 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1159 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1160
1165
1166 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
1167 void eval(const std::vector<array>& inputs, array& out);
1168};
1169
// Gather: single-output primitive parameterized by a list of axes and a
// Shape of slice sizes.
// NOTE(review): presumably gathers slices of `slice_sizes` from the input at
// indices along `axes` — confirm in the implementation.
// NOTE(review): listing number 1173 (the start of the member-initializer
// list, presumably the UnaryPrimitive base initializer) is missing below.
1170class Gather : public UnaryPrimitive {
1171 public:
// Takes both arguments by value and moves them into the members.
1172 explicit Gather(Stream stream, std::vector<int> axes, Shape slice_sizes)
1174 axes_(std::move(axes)),
1175 slice_sizes_(std::move(slice_sizes)) {}
1176
1177 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1178 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1179
1183 bool is_equivalent(const Primitive& other) const override;
1184 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
// Returns copies of (axes, slice_sizes). The second element is declared as
// std::vector<int> while the member is a Shape.
// NOTE(review): this relies on Shape being (convertible to) std::vector<int>
// — confirm against the Shape definition.
1185 std::pair<std::vector<int>, std::vector<int>> state() const {
1186 return {axes_, slice_sizes_};
1187 }
1188
1189 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
1190 void eval(const std::vector<array>& inputs, array& out);
1191 std::vector<int> axes_;
1192 Shape slice_sizes_;
1193};
1194
// Greater: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably element-wise `a > b` — confirm. The constructor
// and other members (listing numbers 1197, 1202-1206) are not shown.
1195class Greater : public UnaryPrimitive {
1196 public:
1198
1199 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1200 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1201
1207
1208 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
1209 void eval(const std::vector<array>& inputs, array& out);
1210};
1211
// NOTE(review): the class header (listing number 1212) and constructor (1214)
// are missing from this listing, so the class name cannot be determined from
// the visible lines. The `array& out` overrides below match the
// UnaryPrimitive interface, so this is presumably a single-output primitive —
// confirm against the original header.
1213 public:
1215
1216 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1217 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1218
1224
1225 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
1226 void eval(const std::vector<array>& inputs, array& out);
1227};
1228
// Hadamard: single-output primitive parameterized by a float scale.
// NOTE(review): presumably a Hadamard transform whose result is multiplied by
// `scale` — confirm in the implementation.
1229class Hadamard : public UnaryPrimitive {
1230 public:
// Forwards `stream` to UnaryPrimitive and stores the scale.
1231 explicit Hadamard(Stream stream, float scale)
1232 : UnaryPrimitive(stream), scale_(scale) {}
1233
1234 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1235 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1236
1241
1242 bool is_equivalent(const Primitive& other) const override;
// Returns the stored scale.
1243 auto state() const {
1244 return scale_;
1245 }
1246
1247 private:
1248 float scale_;
1249
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
1250 void eval(const std::vector<array>& inputs, array& out);
1251};
1252
// Imag: single-output primitive with CPU and GPU evaluation overrides. Unlike
// most sibling classes, no private eval helper is declared in the visible
// lines.
// NOTE(review): presumably extracts the imaginary part of a complex input —
// confirm. The constructor and other members (listing numbers 1255,
// 1260-1264) are not shown.
1253class Imag : public UnaryPrimitive {
1254 public:
1256
1257 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1258 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1259
1265};
1266
// Less: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably element-wise `a < b` — confirm. The constructor
// and other members (listing numbers 1269, 1274-1278) are not shown.
1267class Less : public UnaryPrimitive {
1268 public:
1270
1271 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1272 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1273
1279
1280 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
1281 void eval(const std::vector<array>& inputs, array& out);
1282};
1283
// NOTE(review): the class header (listing number 1284) and constructor (1286)
// are missing from this listing, so the class name cannot be determined from
// the visible lines. The `array& out` overrides below match the
// UnaryPrimitive interface, so this is presumably a single-output primitive —
// confirm against the original header.
1285 public:
1287
1288 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1289 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1290
1296
1297 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
1298 void eval(const std::vector<array>& inputs, array& out);
1299};
1300
// Load: single-output primitive that holds a shared io::Reader, a byte/element
// offset and an endianness-swap flag.
// NOTE(review): presumably reads array data from `reader` starting at
// `offset`; the unit of `offset` is not visible here — confirm in the
// implementation.
// NOTE(review): listing number 1308 (the start of the member-initializer
// list, presumably the UnaryPrimitive base initializer) is missing below.
1301class Load : public UnaryPrimitive {
1302 public:
// Takes the reader by value and moves it into the member. When the target
// stream's device is the GPU, io_stream() is called and its result discarded;
// the visible effect is to force creation of the function-local static stream
// at construction time.
1303 explicit Load(
1304 Stream stream,
1305 std::shared_ptr<io::Reader> reader,
1306 size_t offset,
1307 bool swap_endianness = false)
1309 reader_(std::move(reader)),
1310 offset_(offset),
1311 swap_endianness_(swap_endianness) {
1312 if (stream.device == Device::gpu) {
1313 io_stream();
1314 }
1315 }
1316
1317 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1318 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1319
1321
1322 private:
// Returns a reference to a function-local static Stream created once via
// new_stream(Device::cpu) on first call.
// NOTE(review): presumably the GPU path performs its file I/O on this CPU
// stream — confirm in eval_gpu.
1323 Stream& io_stream() {
1324 static Stream io_stream = new_stream(Device::cpu);
1325 return io_stream;
1326 };
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
1327 void eval(const std::vector<array>& inputs, array& out);
1328 std::shared_ptr<io::Reader> reader_;
1329 size_t offset_;
1330 bool swap_endianness_;
1331};
1332
// Log: single-output primitive parameterized by a logarithm base selector.
// The base also determines the printed name ("Log", "Log2" or "Log10").
// NOTE(review): presumably computes the element-wise logarithm in the
// selected base — confirm in the implementation.
1333class Log : public UnaryPrimitive {
1334 public:
// Unscoped enum selecting the base; enumerators are `two`, `ten` and `e`.
1335 enum Base { two, ten, e };
1336
// Forwards `stream` to UnaryPrimitive and stores the base.
1337 explicit Log(Stream stream, Base base)
1338 : UnaryPrimitive(stream), base_(base) {}
1339
1340 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1341 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1342
1347
// Returns the stored base.
1348 Base state() const {
1349 return base_;
1350 };
1351
// Writes the name matching base_. All three enumerators are handled and
// there is no default branch, so nothing is written for any other value.
1352 void print(std::ostream& os) override {
1353 switch (base_) {
1354 case e:
1355 os << "Log";
1356 break;
1357 case two:
1358 os << "Log2";
1359 break;
1360 case ten:
1361 os << "Log10";
1362 break;
1363 }
1364 }
1365
1366 private:
1367 Base base_;
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
1368 void eval(const std::vector<array>& inputs, array& out);
1369};
1370
// Log1p: single-output primitive with CPU and GPU evaluation overrides.
// NOTE(review): presumably element-wise log(1 + x) — confirm. The constructor
// and other members (listing numbers 1373, 1378-1381) are not shown.
1371class Log1p : public UnaryPrimitive {
1372 public:
1374
1375 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1376 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1377
1382
1383 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
1384 void eval(const std::vector<array>& inputs, array& out);
1385};
1386
// NOTE(review): the class header (listing number 1387) and constructor (1389)
// are missing from this listing, so the class name cannot be determined from
// the visible lines. The `array& out` overrides below match the
// UnaryPrimitive interface, so this is presumably a single-output primitive —
// confirm against the original header.
1388 public:
1390
1391 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1392 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1393
1399
1400 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
1401 void eval(const std::vector<array>& inputs, array& out);
1402};
1403
// NOTE(review): the class header (listing number 1404) and constructor (1406)
// are missing from this listing, so the class name cannot be determined from
// the visible lines. The `array& out` overrides below match the
// UnaryPrimitive interface, so this is presumably a single-output primitive —
// confirm against the original header.
1405 public:
1407
1408 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1409 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1410
1416
1417 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
1418 void eval(const std::vector<array>& inputs, array& out);
1419};
1420
// NOTE(review): the class header (listing number 1421) and constructor (1423)
// are missing from this listing, so the class name cannot be determined from
// the visible lines. The `array& out` overrides below match the
// UnaryPrimitive interface, so this is presumably a single-output primitive —
// confirm against the original header.
1422 public:
1424
1425 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1426 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1427
1433
1434 private:
// Private helper; presumably shared by eval_cpu/eval_gpu — TODO confirm.
1435 void eval(const std::vector<array>& inputs, array& out);
1436};
1437
// NOTE(review): the class declaration line (original 1438) is missing from
// this listing, as are original lines 1440 and 1445-1449. The members below
// belong to that class: CPU/GPU evaluation overrides and a private eval().
1439 public:
1441
1442 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1443 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1444
1450
1451 private:
1452 void eval(const std::vector<array>& inputs, array& out);
1453};
1454
// Matmul primitive: declares CPU/GPU evaluation, a vjp override and an
// output_shapes override. All of them are defined outside this listing.
// NOTE(review): this listing omits original lines 1457 and 1468-1470.
1455class Matmul : public UnaryPrimitive {
1456 public:
1458
1459 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1460 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1461
 // Vector-Jacobian product; takes primals, cotangents, argnums and outputs.
1462 std::vector<array> vjp(
1463 const std::vector<array>& primals,
1464 const std::vector<array>& cotangents,
1465 const std::vector<int>& argnums,
1466 const std::vector<array>& outputs) override;
1467
1471 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1472};
1473
// Maximum primitive: declares the CPU and GPU evaluation overrides and a
// private eval() (presumably a shared implementation - confirm).
// NOTE(review): this listing omits original lines 1476 and 1481-1485.
1474class Maximum : public UnaryPrimitive {
1475 public:
1477
1478 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1479 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1480
1486
1487 private:
1488 void eval(const std::vector<array>& inputs, array& out);
1489};
1490
// Minimum primitive: declares the CPU and GPU evaluation overrides and a
// private eval() (presumably a shared implementation - confirm).
// NOTE(review): this listing omits original lines 1493 and 1498-1502.
1491class Minimum : public UnaryPrimitive {
1492 public:
1494
1495 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1496 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1497
1503
1504 private:
1505 void eval(const std::vector<array>& inputs, array& out);
1506};
1507
// Multiply primitive: declares the CPU and GPU evaluation overrides and a
// private eval() (presumably a shared implementation - confirm).
// NOTE(review): this listing omits original lines 1510 and 1515-1519.
1508class Multiply : public UnaryPrimitive {
1509 public:
1511
1512 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1513 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1514
1520
1521 private:
1522 void eval(const std::vector<array>& inputs, array& out);
1523};
1524
// Negative primitive: declares the CPU and GPU evaluation overrides and a
// private eval() (presumably a shared implementation - confirm).
// NOTE(review): this listing omits original lines 1527 and 1532-1536.
1525class Negative : public UnaryPrimitive {
1526 public:
1528
1529 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1530 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1531
1537
1538 private:
1539 void eval(const std::vector<array>& inputs, array& out);
1540};
1541
// NotEqual primitive: declares the CPU and GPU evaluation overrides and a
// private eval() (presumably a shared implementation - confirm).
// NOTE(review): this listing omits original lines 1544 and 1549-1553.
1542class NotEqual : public UnaryPrimitive {
1543 public:
1545
1546 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1547 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1548
1554
1555 private:
1556 void eval(const std::vector<array>& inputs, array& out);
1557};
1558
// NOTE(review): the class declaration line (original 1559) is missing from
// this listing, as are original lines 1561, 1566 and 1574-1575.
// The visible members store a list of axes, an inverted flag and a dtype,
// and report a single empty shape as the output shape.
1560 public:
 // Constructor parameter list; axes is taken by value and moved into axes_.
1562 Stream stream,
1563 std::vector<int> axes,
1564 bool inverted,
1565 Dtype dtype)
1567 axes_(std::move(axes)),
1568 inverted_(inverted),
1569 dtype_(dtype) {}
1570
1571 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1572 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1573
1576 bool is_equivalent(const Primitive& other) const override;
 // Always one output whose shape is empty, regardless of the inputs.
1577 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
1578 return {{}};
1579 }
 // Returns copies of (axes_, inverted_, dtype_).
1580 std::tuple<std::vector<int>, bool, Dtype> state() const {
1581 return {axes_, inverted_, dtype_};
1582 }
1583
1584 private:
1585 std::vector<int> axes_;
1586 bool inverted_;
1587 Dtype dtype_;
1588
1589 void eval(const std::vector<array>& inputs, array& out);
1590};
1591
// Pad primitive. Stores the axes together with a low and a high pad size
// (both of type Shape), copied from the constructor arguments.
// NOTE(review): this listing omits original lines 1599 and 1607-1609.
1592class Pad : public UnaryPrimitive {
1593 public:
1594 explicit Pad(
1595 Stream stream,
1596 const std::vector<int>& axes,
1597 const Shape& low_pad_size,
1598 const Shape& high_pad_size)
1600 axes_(axes),
1601 low_pad_size_(low_pad_size),
1602 high_pad_size_(high_pad_size) {}
1603
1604 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1605 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1606
1610 bool is_equivalent(const Primitive& other) const override;
 // Returns a tuple of copies: (axes_, low_pad_size_, high_pad_size_).
1611 auto state() const {
1612 return std::make_tuple(axes_, low_pad_size_, high_pad_size_);
1613 }
1614
1615 private:
1616 std::vector<int> axes_;
1617 Shape low_pad_size_;
1618 Shape high_pad_size_;
1619
1620 void eval(const std::vector<array>& inputs, array& out);
1621};
1622
// Partition primitive (name taken from the constructor below). Stores the
// integers kth and axis given at construction.
// NOTE(review): the class declaration line (original 1623) and original
// lines 1631-1634 are missing from this listing.
1624 public:
1625 explicit Partition(Stream stream, int kth, int axis)
1626 : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
1627
1628 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1629 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1630
1635 bool is_equivalent(const Primitive& other) const override;
 // Returns the pair (kth_, axis_).
1636 auto state() const {
1637 return std::make_pair(kth_, axis_);
1638 };
1639
1640 private:
1641 int kth_;
1642 int axis_;
1643
1644 void eval(const std::vector<array>& inputs, array& out);
1645};
1646
// Power primitive: declares the CPU and GPU evaluation overrides and a
// private eval() (presumably a shared implementation - confirm).
// NOTE(review): this listing omits original lines 1649 and 1654-1658.
1647class Power : public UnaryPrimitive {
1648 public:
1650
1651 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1652 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1653
1659
1660 private:
1661 void eval(const std::vector<array>& inputs, array& out);
1662};
1663
// NOTE(review): the class declaration line (original 1664) is missing from
// this listing, as are original lines 1666, 1671 and 1679-1681.
// The visible members store a group size, a bit count and a transpose flag.
1665 public:
 // Constructor parameter list; each value is copied into the matching member.
1667 Stream stream,
1668 int group_size,
1669 int bits,
1670 bool transpose)
1672 group_size_(group_size),
1673 bits_(bits),
1674 transpose_(transpose) {}
1675
1676 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1677 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1678
1682 bool is_equivalent(const Primitive& other) const override;
1683 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
 // Returns the tuple (group_size_, bits_, transpose_).
1684 auto state() const {
1685 return std::make_tuple(group_size_, bits_, transpose_);
1686 }
1687
1688 private:
1689 int group_size_;
1690 int bits_;
1691 bool transpose_;
1692
1693 void eval(const std::vector<array>& inputs, array& out);
1694};
1695
// GatherQMM primitive (name taken from the constructor below). Stores a
// group size, a bit count and a transpose flag given at construction.
// NOTE(review): the class declaration line (original 1696) and original
// lines 1699 and 1707-1709 are missing from this listing.
1697 public:
1698 explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
1700 group_size_(group_size),
1701 bits_(bits),
1702 transpose_(transpose) {}
1703
1704 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1705 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1706
1710 bool is_equivalent(const Primitive& other) const override;
 // Returns the tuple (group_size_, bits_, transpose_).
1711 auto state() const {
1712 return std::make_tuple(group_size_, bits_, transpose_);
1713 }
1714
1715 private:
1716 int group_size_;
1717 int bits_;
1718 bool transpose_;
1719
1720 void eval(const std::vector<array>& inputs, array& out);
1721};
1722
// RandomBits primitive (name taken from the constructor below). Stores a
// Shape and an integer width given at construction.
// NOTE(review): the class declaration line (original 1723) and original
// lines 1731-1732 are missing from this listing.
1724 public:
1725 explicit RandomBits(Stream stream, const Shape& shape, int width)
1726 : UnaryPrimitive(stream), shape_(shape), width_(width) {}
1727
1728 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1729 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1730
1733 bool is_equivalent(const Primitive& other) const override;
 // Returns (shape_, width_); relies on Shape converting to std::vector<int>.
1734 std::pair<std::vector<int>, int> state() const {
1735 return {shape_, width_};
1736 };
1737
1738 private:
1739 Shape shape_;
1740 int width_;
1741
1742 void eval(const std::vector<array>& inputs, array& out);
1743};
1744
// Real primitive: only the CPU and GPU evaluation overrides are visible.
// NOTE(review): this listing omits original lines 1747 and 1752-1756.
1745class Real : public UnaryPrimitive {
1746 public:
1748
1749 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1750 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1751
1757};
1758
// Reshape primitive. Stores the target Shape passed to the constructor and
// exposes it through state().
// NOTE(review): this listing omits original lines 1767-1769.
1759class Reshape : public UnaryPrimitive {
1760 public:
1761 explicit Reshape(Stream stream, const Shape& shape)
1762 : UnaryPrimitive(stream), shape_(shape) {}
1763
1764 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1765 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1766
1770 bool is_equivalent(const Primitive& other) const override;
 // Returns a copy of shape_ as std::vector<int>.
1771 std::vector<int> state() const {
1772 return shape_;
1773 };
 // Static helper taking an input array and a requested shape; defined
 // outside this listing.
1774 static Shape output_shape(const array& input, Shape shape);
1775 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1776
1777 private:
1778 Shape shape_;
1779};
1780
// Reduce primitive. Holds a ReduceType (And, Or, Sum, Prod, Min, Max) and
// the list of axes given at construction; print() writes the type name.
// NOTE(review): this listing omits original line 1794.
1781class Reduce : public UnaryPrimitive {
1782 public:
1783 enum ReduceType { And, Or, Sum, Prod, Min, Max };
1784
1785 explicit Reduce(
1786 Stream stream,
1787 ReduceType reduce_type,
1788 const std::vector<int>& axes)
1789 : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1790
1791 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1792 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1793
1795
 // Vector-Jacobian product; defined outside this listing.
1796 std::vector<array> vjp(
1797 const std::vector<array>& primals,
1798 const std::vector<array>& cotangents,
1799 const std::vector<int>& argnums,
1800 const std::vector<array>& outputs) override;
1801
1802 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1803
 // Writes the enumerator name of reduce_type_ to os.
1804 void print(std::ostream& os) override {
1805 switch (reduce_type_) {
1806 case And:
1807 os << "And";
1808 break;
1809 case Or:
1810 os << "Or";
1811 break;
1812 case Sum:
1813 os << "Sum";
1814 break;
1815 case Prod:
1816 os << "Prod";
1817 break;
1818 case Min:
1819 os << "Min";
1820 break;
1821 case Max:
1822 os << "Max";
1823 break;
1824 }
1825 }
1826 bool is_equivalent(const Primitive& other) const override;
 // Returns (reduce_type_, copy of axes_).
1827 std::pair<ReduceType, std::vector<int>> state() const {
1828 return {reduce_type_, axes_};
1829 };
1830
1831 private:
1832 ReduceType reduce_type_;
1833 std::vector<int> axes_;
1834
1835 void eval(const std::vector<array>& inputs, array& out);
1836};
1837
// Round primitive: declares the CPU and GPU evaluation overrides and a
// private eval() (presumably a shared implementation - confirm).
// NOTE(review): this listing omits original lines 1840 and 1845-1849.
1838class Round : public UnaryPrimitive {
1839 public:
1841
1842 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1843 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1844
1850
1851 private:
1852 void eval(const std::vector<array>& inputs, array& out);
1853};
1854
// Scan primitive. Stores a reduce type, an axis and the reverse/inclusive
// flags given at construction; print() writes Cum followed by the type name.
// NOTE(review): this listing omits original lines 1857, 1865 and 1874-1875.
// ReduceType is not declared in the visible lines; presumably it is the
// omitted line 1857 - confirm.
1855class Scan : public UnaryPrimitive {
1856 public:
1858
1859 explicit Scan(
1860 Stream stream,
1861 ReduceType reduce_type,
1862 int axis,
1863 bool reverse,
1864 bool inclusive)
1866 reduce_type_(reduce_type),
1867 axis_(axis),
1868 reverse_(reverse),
1869 inclusive_(inclusive) {}
1870
1871 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1872 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1873
1876
 // Writes CumSum, CumProd, CumMin or CumMax depending on reduce_type_.
1877 void print(std::ostream& os) override {
1878 os << "Cum";
1879 switch (reduce_type_) {
1880 case Sum:
1881 os << "Sum";
1882 break;
1883 case Prod:
1884 os << "Prod";
1885 break;
1886 case Min:
1887 os << "Min";
1888 break;
1889 case Max:
1890 os << "Max";
1891 break;
1892 }
1893 }
1894 bool is_equivalent(const Primitive& other) const override;
 // Returns the tuple (reduce_type_, axis_, reverse_, inclusive_).
1895 auto state() const {
1896 return std::make_tuple(reduce_type_, axis_, reverse_, inclusive_);
1897 }
1898
1899 private:
1900 ReduceType reduce_type_;
1901 int axis_;
1902 bool reverse_;
1903 bool inclusive_;
1904
1905 void eval(const std::vector<array>& inputs, array& out);
1906};
1907
// Scatter primitive. Stores a reduce type and the list of axes given at
// construction; print() writes Scatter plus an optional type suffix.
// NOTE(review): this listing omits original lines 1910 and 1921-1922.
// ReduceType (with a None enumerator) is not declared in the visible lines;
// presumably it is the omitted line 1910 - confirm.
1908class Scatter : public UnaryPrimitive {
1909 public:
1911
1912 explicit Scatter(
1913 Stream stream,
1914 ReduceType reduce_type,
1915 const std::vector<int>& axes)
1916 : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1917
1918 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1919 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1920
1923
 // Writes Scatter, then a space and the type name; None adds no suffix.
1924 void print(std::ostream& os) override {
1925 os << "Scatter";
1926 switch (reduce_type_) {
1927 case Sum:
1928 os << " Sum";
1929 break;
1930 case Prod:
1931 os << " Prod";
1932 break;
1933 case Min:
1934 os << " Min";
1935 break;
1936 case Max:
1937 os << " Max";
1938 break;
1939 case None:
1940 break;
1941 }
1942 }
1943 bool is_equivalent(const Primitive& other) const override;
 // Returns (reduce_type_, copy of axes_).
1944 std::pair<ReduceType, std::vector<int>> state() const {
1945 return {reduce_type_, axes_};
1946 };
1947
1948 private:
1949 void eval(const std::vector<array>& inputs, array& out);
1950 ReduceType reduce_type_;
1951 std::vector<int> axes_;
1952};
1953
// Sigmoid primitive: declares the CPU and GPU evaluation overrides and a
// private eval() (presumably a shared implementation - confirm).
// NOTE(review): this listing omits original lines 1956 and 1961-1965.
1954class Sigmoid : public UnaryPrimitive {
1955 public:
1957
1958 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1959 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1960
1966
1967 private:
1968 void eval(const std::vector<array>& inputs, array& out);
1969};
1970
// Sign primitive: declares the CPU and GPU evaluation overrides and a
// private eval() (presumably a shared implementation - confirm).
// NOTE(review): this listing omits original lines 1973 and 1978-1982.
1971class Sign : public UnaryPrimitive {
1972 public:
1974
1975 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1976 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1977
1983
1984 private:
1985 void eval(const std::vector<array>& inputs, array& out);
1986};
1987
// Sin primitive: declares the CPU and GPU evaluation overrides and a
// private eval() (presumably a shared implementation - confirm).
// NOTE(review): this listing omits original lines 1990 and 1995-1999.
1988class Sin : public UnaryPrimitive {
1989 public:
1991
1992 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1993 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1994
2000
2001 private:
2002 void eval(const std::vector<array>& inputs, array& out);
2003};
2004
// Sinh primitive: declares the CPU and GPU evaluation overrides and a
// private eval() (presumably a shared implementation - confirm).
// NOTE(review): this listing omits original lines 2007 and 2012-2016.
2005class Sinh : public UnaryPrimitive {
2006 public:
2008
2009 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2010 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2011
2017
2018 private:
2019 void eval(const std::vector<array>& inputs, array& out);
2020};
2021
// Slice primitive. Stores start indices, end indices and strides (all of
// type Shape), copied from the constructor arguments.
// NOTE(review): this listing omits original lines 2029 and 2037-2039.
2022class Slice : public UnaryPrimitive {
2023 public:
2024 explicit Slice(
2025 Stream stream,
2026 const Shape& start_indices,
2027 const Shape& end_indices,
2028 const Shape& strides)
2030 start_indices_(start_indices),
2031 end_indices_(end_indices),
2032 strides_(strides) {}
2033
2034 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2035 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2036
2040 bool is_equivalent(const Primitive& other) const override;
 // Returns a tuple of copies: (start_indices_, end_indices_, strides_).
2041 auto state() const {
2042 return std::make_tuple(start_indices_, end_indices_, strides_);
2043 }
2044
2045 private:
2046 Shape start_indices_;
2047 Shape end_indices_;
2048 Shape strides_;
2049
2050 void eval(const std::vector<array>& inputs, array& out);
2051};
2052
// SliceUpdate primitive (name taken from the constructor below). Stores
// start indices, end indices and strides, all of type Shape.
// NOTE(review): the class declaration line (original 2053) and original
// lines 2060, 2068-2070 and 2072 are missing from this listing.
2054 public:
2055 explicit SliceUpdate(
2056 Stream stream,
2057 const Shape& start_indices,
2058 const Shape& end_indices,
2059 const Shape& strides)
2061 start_indices_(start_indices),
2062 end_indices_(end_indices),
2063 strides_(strides) {}
2064
2065 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2066 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2067
2071 bool is_equivalent(const Primitive& other) const override;
 // Returns a tuple of copies: (start_indices_, end_indices_, strides_).
2073 auto state() const {
2074 return std::make_tuple(start_indices_, end_indices_, strides_);
2075 }
2076
2077 private:
2078 Shape start_indices_;
2079 Shape end_indices_;
2080 Shape strides_;
2081
2082 void eval(const std::vector<array>& inputs, array& out);
2083};
2084
// DynamicSlice primitive (name taken from the constructor below). Takes the
// axes and slice size by value and moves them into its members.
// NOTE(review): the class declaration line (original 2085) and original
// lines 2088 and 2095-2097 are missing from this listing.
2086 public:
2087 explicit DynamicSlice(Stream stream, std::vector<int> axes, Shape slice_size)
2089 axes_(std::move(axes)),
2090 slice_size_(std::move(slice_size)) {}
2091
2092 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2093 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2094
2098 bool is_equivalent(const Primitive& other) const override;
2099 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
 // Returns the pair (copy of axes_, copy of slice_size_).
2100 auto state() const {
2101 return std::make_pair(axes_, slice_size_);
2102 }
2103
2104 private:
2105 std::vector<int> axes_;
2106 Shape slice_size_;
2107};
2108
// DynamicSliceUpdate primitive (name taken from the constructor below).
// Takes the axes by value and moves them into axes_.
// NOTE(review): the class declaration line (original 2109) and original
// lines 2117-2119 and 2121 are missing from this listing.
2110 public:
2111 explicit DynamicSliceUpdate(Stream stream, std::vector<int> axes)
2112 : UnaryPrimitive(stream), axes_(std::move(axes)) {}
2113
2114 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2115 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2116
2120 bool is_equivalent(const Primitive& other) const override;
 // Returns a copy of axes_ (auto deduces a value, not a reference).
2122 auto state() const {
2123 return axes_;
2124 }
2125
2126 private:
2127 std::vector<int> axes_;
2128};
2129
// Softmax primitive. Stores the precise flag given at construction and
// returns it from state(). How the flag affects evaluation is not visible
// in this listing.
// NOTE(review): this listing omits original lines 2138-2141.
2130class Softmax : public UnaryPrimitive {
2131 public:
2132 explicit Softmax(Stream stream, bool precise)
2133 : UnaryPrimitive(stream), precise_(precise) {}
2134
2135 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2136 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2137
2142
2143 bool is_equivalent(const Primitive& other) const override;
 // Returns precise_.
2144 auto state() const {
2145 return precise_;
2146 };
2147
2148 private:
2149 void eval(const std::vector<array>& inputs, array& out);
2150 bool precise_;
2151};
2152
// Sort primitive. Stores the integer axis given at construction and returns
// it from state().
// NOTE(review): this listing omits original lines 2161-2164.
2153class Sort : public UnaryPrimitive {
2154 public:
2155 explicit Sort(Stream stream, int axis)
2156 : UnaryPrimitive(stream), axis_(axis) {}
2157
2158 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2159 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2160
2165 bool is_equivalent(const Primitive& other) const override;
 // Returns axis_.
2166 auto state() const {
2167 return axis_;
2168 }
2169
2170 private:
2171 int axis_;
2172
2173 void eval(const std::vector<array>& inputs, array& out);
2174};
2175
// Split primitive. Derives from Primitive directly, so its evaluation
// overrides take a vector of output arrays. Stores the split indices (a
// Shape) and the axis given at construction.
// NOTE(review): this listing omits original lines 2186-2188.
2176class Split : public Primitive {
2177 public:
2178 explicit Split(Stream stream, const Shape& indices, int axis)
2179 : Primitive(stream), indices_(indices), axis_(axis) {}
2180
2181 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2182 override;
2183 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2184 override;
2185
2189 bool is_equivalent(const Primitive& other) const override;
 // Returns (indices_, axis_); relies on Shape converting to std::vector<int>.
2190 std::pair<std::vector<int>, int> state() const {
2191 return {indices_, axis_};
2192 };
2193
2194 private:
2195 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
2196
2197 Shape indices_;
2198 int axis_;
2199};
2200
// Square primitive: declares the CPU and GPU evaluation overrides and a
// private eval() (presumably a shared implementation - confirm).
// NOTE(review): this listing omits original lines 2203 and 2208-2212.
2201class Square : public UnaryPrimitive {
2202 public:
2204
2205 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2206 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2207
2213
2214 private:
2215 void eval(const std::vector<array>& inputs, array& out);
2216};
2217
// Sqrt primitive. The recip flag (default false) is stored at construction,
// returned by state(), and makes print() write Rsqrt instead of Sqrt.
// NOTE(review): this listing omits original lines 2226-2228.
2218class Sqrt : public UnaryPrimitive {
2219 public:
2220 explicit Sqrt(Stream stream, bool recip = false)
2221 : UnaryPrimitive(stream), recip_(recip) {}
2222
2223 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2224 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2225
2229 bool is_equivalent(const Primitive& other) const override;
 // Returns recip_.
2230 auto state() const {
2231 return recip_;
2232 }
2233
 // Writes Rsqrt when recip_ is set, Sqrt otherwise.
2234 void print(std::ostream& os) override {
2235 if (recip_) {
2236 os << "Rsqrt";
2237 } else {
2238 os << "Sqrt";
2239 }
2240 }
2241
2242 private:
2243 void eval(const std::vector<array>& inputs, array& out);
2244 bool recip_;
2245};
2246
// NOTE(review): the class declaration line (original 2247) is missing from
// this listing, as are original lines 2249 and 2254-2257. The members below
// belong to that class: CPU/GPU evaluation overrides and a private eval().
2248 public:
2250
2251 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2252 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2253
2258
2259 private:
2260 void eval(const std::vector<array>& inputs, array& out);
2261};
2262
// Subtract primitive: declares the CPU and GPU evaluation overrides and a
// private eval() (presumably a shared implementation - confirm).
// NOTE(review): this listing omits original lines 2265 and 2270-2274.
2263class Subtract : public UnaryPrimitive {
2264 public:
2266
2267 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2268 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2269
2275
2276 private:
2277 void eval(const std::vector<array>& inputs, array& out);
2278};
2279
// Squeeze primitive. Takes the axes by value, moves them into axes_ and
// returns a copy from state().
// NOTE(review): this listing omits original lines 2288-2290.
2280class Squeeze : public UnaryPrimitive {
2281 public:
2282 explicit Squeeze(Stream stream, std::vector<int> axes)
2283 : UnaryPrimitive(stream), axes_(std::move(axes)) {}
2284
2285 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2286 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2287
2291
2292 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
2293 bool is_equivalent(const Primitive& other) const override;
2294
 // Static helper taking an input array and a list of axes; defined outside
 // this listing.
2295 static Shape output_shape(const array& input, const std::vector<int>& axes);
2296 auto state() const {
2297 return axes_;
2298 };
2299
2300 private:
2301 void eval(const std::vector<array>& inputs, array& out);
2302 std::vector<int> axes_;
2303};
2304
// Tan primitive: declares the CPU and GPU evaluation overrides and a
// private eval() (presumably a shared implementation - confirm).
// NOTE(review): this listing omits original lines 2307 and 2312-2316.
2305class Tan : public UnaryPrimitive {
2306 public:
2308
2309 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2310 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2311
2317
2318 private:
2319 void eval(const std::vector<array>& inputs, array& out);
2320};
2321
// Tanh primitive: declares the CPU and GPU evaluation overrides and a
// private eval() (presumably a shared implementation - confirm).
// NOTE(review): this listing omits original lines 2324 and 2329-2333.
2322class Tanh : public UnaryPrimitive {
2323 public:
2325
2326 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2327 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2328
2334
2335 private:
2336 void eval(const std::vector<array>& inputs, array& out);
2337};
2338
// Unflatten primitive (name taken from the constructor below). Stores an
// axis and a Shape; the shape is taken by value and moved into shape_.
// NOTE(review): the class declaration line (original 2339) and original
// lines 2347-2349 are missing from this listing.
2340 public:
2341 explicit Unflatten(Stream stream, int axis, Shape shape)
2342 : UnaryPrimitive(stream), axis_(axis), shape_(std::move(shape)) {}
2343
2344 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2345 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2346
2350 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
2351 bool is_equivalent(const Primitive& other) const override;
2352
 // Static helper taking an input array, an axis and a shape; defined
 // outside this listing.
2353 static Shape output_shape(const array& input, int axis, const Shape& shape);
 // Returns the pair (axis_, copy of shape_).
2354 auto state() const {
2355 return std::make_pair(axis_, shape_);
2356 }
2357
2358 private:
2359 int axis_;
2360 Shape shape_;
2361 void eval(const std::vector<array>& inputs, array& out);
2362};
2363
// View primitive. Stores the Dtype given at construction and returns it
// from state(); print() is declared here and defined elsewhere.
// NOTE(review): this listing omits original line 2372.
2364class View : public UnaryPrimitive {
2365 public:
2366 explicit View(Stream stream, Dtype dtype)
2367 : UnaryPrimitive(stream), dtype_(dtype) {}
2368
2369 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2370 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2371
2373 void print(std::ostream& os) override;
2374 bool is_equivalent(const Primitive& other) const override;
 // Returns dtype_.
2375 auto state() const {
2376 return dtype_;
2377 }
2378
2379 private:
2380 Dtype dtype_;
2381};
2382
// Transpose primitive (name taken from the constructor below). Stores a
// copy of the axes given at construction and returns it from state().
// NOTE(review): the class declaration line (original 2383) and original
// lines 2391-2393 are missing from this listing.
2384 public:
2385 explicit Transpose(Stream stream, const std::vector<int>& axes)
2386 : UnaryPrimitive(stream), axes_(axes) {}
2387
2388 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2389 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2390
2394 bool is_equivalent(const Primitive& other) const override;
2395 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
 // Returns a copy of axes_.
2396 std::vector<int> state() const {
2397 return axes_;
2398 };
2399
2400 private:
2401 std::vector<int> axes_;
2402
2403 void eval(const std::vector<array>& inputs, array& out);
2404};
2405
2406/* QR Factorization primitive. */
// Derives from Primitive directly, so its evaluation overrides and the
// private eval() take a vector of output arrays.
// NOTE(review): this listing omits original lines 2409 and 2416.
2407class QRF : public Primitive {
2408 public:
2410
2411 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2412 override;
2413 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2414 override;
2415
2417
2418 private:
2419 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
2420};
2421
2422/* SVD primitive. */
// Derives from Primitive directly, so its evaluation overrides and the
// private eval() take a vector of output arrays.
// NOTE(review): this listing omits original lines 2425 and 2432-2433.
2423class SVD : public Primitive {
2424 public:
2426
2427 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2428 override;
2429 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2430 override;
2431
2434
2435 private:
2436 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
2437};
2438
2439/* Matrix inversion primitive. */
// Stores the tri and upper flags given at construction and returns them as
// a pair from state(). Their effect on evaluation is not visible here
// (presumably triangular input and which triangle - confirm).
// NOTE(review): this listing omits original lines 2448-2449.
2440class Inverse : public UnaryPrimitive {
2441 public:
2442 explicit Inverse(Stream stream, bool tri, bool upper)
2443 : UnaryPrimitive(stream), tri_(tri), upper_(upper) {}
2444
2445 void eval_cpu(const std::vector<array>& inputs, array& output) override;
2446 void eval_gpu(const std::vector<array>& inputs, array& output) override;
2447
 // Returns the pair (tri_, upper_).
2450 auto state() const {
2451 return std::make_pair(tri_, upper_);
2452 }
2453
2454 private:
2455 void eval(const std::vector<array>& inputs, array& output);
2456 bool tri_;
2457 bool upper_;
2458};
2459
// Cholesky primitive. Stores the upper flag given at construction and
// returns it from state().
// NOTE(review): this listing omits original lines 2471-2472.
2460class Cholesky : public UnaryPrimitive {
2461 public:
2462 explicit Cholesky(Stream stream, bool upper)
2463 : UnaryPrimitive(stream), upper_(upper) {}
2464
2465 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2466 void eval_gpu(const std::vector<array>& inputs, array& out) override;
 // Returns upper_.
2467 auto state() const {
2468 return upper_;
2469 }
2470
2473
2474 private:
2475 void eval(const std::vector<array>& inputs, array& output);
2476 bool upper_;
2477};
2478
// Eigh primitive. Derives from Primitive directly, so evaluation takes a
// vector of output arrays. Stores the uplo string (taken by value and
// moved) and the compute_eigenvectors flag given at construction.
// NOTE(review): this listing omits original lines 2491-2492.
2479class Eigh : public Primitive {
2480 public:
2481 explicit Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
2482 : Primitive(stream),
2483 uplo_(std::move(uplo)),
2484 compute_eigenvectors_(compute_eigenvectors) {}
2485
2486 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2487 override;
2488 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2489 override;
2490
2493
2494 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
2495
2496 bool is_equivalent(const Primitive& other) const override;
 // Returns the pair (copy of uplo_, compute_eigenvectors_).
2497 auto state() const {
2498 return std::make_pair(uplo_, compute_eigenvectors_);
2499 }
2500
2501 private:
2502 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
2503 std::string uplo_;
2504 bool compute_eigenvectors_;
2505};
2506
2507} // namespace mlx::core
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Abs(Stream stream)
Definition primitives.h:156
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Add(Stream stream)
Definition primitives.h:173
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::pair< float, float > state() const
Definition primitives.h:206
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
AddMM(Stream stream, float alpha, float beta)
Definition primitives.h:190
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Arange(Stream stream, double start, double stop, double step)
Definition primitives.h:217
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::tuple< double, double, double > state() const
Definition primitives.h:226
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCos(Stream stream)
Definition primitives.h:240
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCosh(Stream stream)
Definition primitives.h:257
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcSin(Stream stream)
Definition primitives.h:274
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcSinh(Stream stream)
Definition primitives.h:291
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTan2(Stream stream)
Definition primitives.h:325
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcTan(Stream stream)
Definition primitives.h:308
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTanh(Stream stream)
Definition primitives.h:342
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::pair< int, int > state() const
Definition primitives.h:370
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArgPartition(Stream stream, int kth, int axis)
Definition primitives.h:359
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
ReduceType
Definition primitives.h:383
@ ArgMin
Definition primitives.h:384
@ ArgMax
Definition primitives.h:385
ArgReduce(Stream stream, ReduceType reduce_type, int axis)
Definition primitives.h:388
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::pair< ReduceType, int > state() const
Definition primitives.h:399
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
ArgSort(Stream stream, int axis)
Definition primitives.h:412
int state() const
Definition primitives.h:422
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:469
AsStrided(Stream stream, Shape shape, Strides strides, size_t offset)
Definition primitives.h:457
void eval_gpu(const std::vector< array > &inputs, array &out) override
AsType(Stream stream, Dtype dtype)
Definition primitives.h:434
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Dtype state() const
Definition primitives.h:445
void eval_cpu(const std::vector< array > &inputs, array &out) override
BitwiseBinary(Stream stream, Op op)
Definition primitives.h:485
void eval_cpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Op
Definition primitives.h:483
@ RightShift
Definition primitives.h:483
@ Or
Definition primitives.h:483
@ LeftShift
Definition primitives.h:483
@ And
Definition primitives.h:483
@ Xor
Definition primitives.h:483
auto state() const
Definition primitives.h:496
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
auto state() const
Definition primitives.h:520
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
BlockMaskedMM(Stream stream, int block_size)
Definition primitives.h:506
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
BroadcastAxes(Stream stream, std::vector< int > ignore_axes={})
Definition primitives.h:552
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:566
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const std::vector< array > &inputs, const std::vector< int > &ignore_axes)
Broadcast(Stream stream, const Shape &shape)
Definition primitives.h:577
static Shape output_shape(const std::vector< array > &inputs)
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< int > state() const
Definition primitives.h:589
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Ceil(Stream stream)
Definition primitives.h:601
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2467
Cholesky(Stream stream, bool upper)
Definition primitives.h:2462
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void print(std::ostream &os) override
Print the primitive.
Compiled(Stream stream, std::vector< array > inputs, std::vector< array > outputs, std::vector< array > tape, std::unordered_set< uintptr_t > constant_ids)
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
std::string lib_name() const
Definition primitives.h:645
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:671
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Concatenate(Stream stream, int axis)
Definition primitives.h:660
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Conjugate(Stream stream)
Definition primitives.h:683
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Contiguous(Stream stream, bool allow_col_major)
Definition primitives.h:699
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
Convolution(Stream stream, const std::vector< int > &kernel_strides, const std::vector< int > &padding, const std::vector< int > &kernel_dilation, const std::vector< int > &input_dilation, const int groups=1, const bool flip=false)
Definition primitives.h:718
auto state() const
Definition primitives.h:745
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
Copy(Stream stream)
Definition primitives.h:768
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cos(Stream stream)
Definition primitives.h:785
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Cosh(Stream stream)
Definition primitives.h:802
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes) override
The primitive must know how to vectorize itself across the given axes.
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
The Jacobian-vector product.
CustomTransforms(Stream stream, int num_outputs, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> vjp, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< int > &)> jvp, std::function< std::pair< std::vector< array >, std::vector< int > >(const std::vector< array > &, const std::vector< int > &)> vmap)
Definition primitives.h:819
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Depends(Stream stream)
Definition primitives.h:871
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:920
DivMod(Stream stream)
Definition primitives.h:909
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Divide(Stream stream)
Definition primitives.h:892
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
DynamicSlice(Stream stream, std::vector< int > axes, Shape slice_size)
Definition primitives.h:2087
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2100
auto state() const
Definition primitives.h:2122
DynamicSliceUpdate(Stream stream, std::vector< int > axes)
Definition primitives.h:2111
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
auto state() const
Definition primitives.h:2497
Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
Definition primitives.h:2481
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:975
Equal(Stream stream, bool equal_nan=false)
Definition primitives.h:964
auto state() const
Definition primitives.h:982
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Erf(Stream stream)
Definition primitives.h:993
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ErfInv(Stream stream)
Definition primitives.h:1010
void eval_cpu(const std::vector< array > &inputs, array &out) override
Exp(Stream stream)
Definition primitives.h:1027
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const array &input, const std::vector< int > &axes)
auto state() const
Definition primitives.h:1074
void eval_gpu(const std::vector< array > &inputs, array &out) override
ExpandDims(Stream stream, std::vector< int > axes)
Definition primitives.h:1060
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Expm1(Stream stream)
Definition primitives.h:1044
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
FFT(Stream stream, const std::vector< size_t > &axes, bool inverse, bool real)
Definition primitives.h:1085
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1100
static Shape output_shape(const array &input, int start_axis, int end_axis)
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
Flatten(Stream stream, int start_axis, int end_axis)
Definition primitives.h:1114
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1127
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Floor(Stream stream)
Definition primitives.h:1139
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Full(Stream stream)
Definition primitives.h:1156
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::pair< std::vector< int >, std::vector< int > > state() const
Definition primitives.h:1185
Gather(Stream stream, std::vector< int > axes, Shape slice_sizes)
Definition primitives.h:1172
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
GatherMM(Stream stream)
Definition primitives.h:532
auto state() const
Definition primitives.h:1711
GatherQMM(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1698
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
GreaterEqual(Stream stream)
Definition primitives.h:1214
void eval_gpu(const std::vector< array > &inputs, array &out) override
Greater(Stream stream)
Definition primitives.h:1197
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
Hadamard(Stream stream, float scale)
Definition primitives.h:1231
auto state() const
Definition primitives.h:1243
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Imag(Stream stream)
Definition primitives.h:1255
void eval_gpu(const std::vector< array > &inputs, array &output) override
Inverse(Stream stream, bool tri, bool upper)
Definition primitives.h:2442
auto state() const
Definition primitives.h:2450
void eval_cpu(const std::vector< array > &inputs, array &output) override
LessEqual(Stream stream)
Definition primitives.h:1286
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Less(Stream stream)
Definition primitives.h:1269
void eval_gpu(const std::vector< array > &inputs, array &out) override
Load(Stream stream, std::shared_ptr< io::Reader > reader, size_t offset, bool swap_endianness=false)
Definition primitives.h:1303
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Log1p(Stream stream)
Definition primitives.h:1373
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogAddExp(Stream stream)
Definition primitives.h:1440
Base
Definition primitives.h:1335
@ ten
Definition primitives.h:1335
@ two
Definition primitives.h:1335
@ e
Definition primitives.h:1335
Log(Stream stream, Base base)
Definition primitives.h:1337
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1352
Base state() const
Definition primitives.h:1348
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalAnd(Stream stream)
Definition primitives.h:1406
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalNot(Stream stream)
Definition primitives.h:1389
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
LogicalOr(Stream stream)
Definition primitives.h:1423
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Matmul(Stream stream)
Definition primitives.h:1457
Maximum(Stream stream)
Definition primitives.h:1476
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Minimum(Stream stream)
Definition primitives.h:1493
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Multiply(Stream stream)
Definition primitives.h:1510
void eval_gpu(const std::vector< array > &inputs, array &out) override
Negative(Stream stream)
Definition primitives.h:1527
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
NotEqual(Stream stream)
Definition primitives.h:1544
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:1577
NumberOfElements(Stream stream, std::vector< int > axes, bool inverted, Dtype dtype)
Definition primitives.h:1561
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::tuple< std::vector< int >, bool, Dtype > state() const
Definition primitives.h:1580
auto state() const
Definition primitives.h:1611
Pad(Stream stream, const std::vector< int > &axes, const Shape &low_pad_size, const Shape &high_pad_size)
Definition primitives.h:1594
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Partition(Stream stream, int kth, int axis)
Definition primitives.h:1625
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
auto state() const
Definition primitives.h:1636
void eval_cpu(const std::vector< array > &inputs, array &out) override
Power(Stream stream)
Definition primitives.h:1649
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:48
virtual void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
The vector-Jacobian product.
virtual ~Primitive()=default
Primitive(const Primitive &other)=delete
Primitive(Primitive &&other)=delete
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
Primitive & operator=(Primitive &&other)=delete
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
Primitive & operator=(const Primitive &other)=delete
const Device & device()
The device the primitive will run on.
Definition primitives.h:53
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
The Jacobian-vector product.
virtual std::vector< Shape > output_shapes(const std::vector< array > &inputs)
Get the output shapes of the primitive.
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes)
The primitive must know how to vectorize itself across the given axes.
virtual void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
virtual void print(std::ostream &os)=0
Print the primitive.
Primitive(Stream stream)
Definition primitives.h:50
QRF(Stream stream)
Definition primitives.h:2409
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
QuantizedMatmul(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1666
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1684
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::pair< std::vector< int >, int > state() const
Definition primitives.h:1734
RandomBits(Stream stream, const Shape &shape, int width)
Definition primitives.h:1725
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Real(Stream stream)
Definition primitives.h:1747
Reduce(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1785
ReduceType
Definition primitives.h:1783
@ Min
Definition primitives.h:1783
@ Or
Definition primitives.h:1783
@ Max
Definition primitives.h:1783
@ And
Definition primitives.h:1783
@ Sum
Definition primitives.h:1783
@ Prod
Definition primitives.h:1783
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1804
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::pair< ReduceType, std::vector< int > > state() const
Definition primitives.h:1827
Remainder(Stream stream)
Definition primitives.h:947
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const array &input, Shape shape)
void eval_gpu(const std::vector< array > &inputs, array &out) override
Reshape(Stream stream, const Shape &shape)
Definition primitives.h:1761
std::vector< int > state() const
Definition primitives.h:1771
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Round(Stream stream)
Definition primitives.h:1840
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
SVD(Stream stream)
Definition primitives.h:2425
void eval_cpu(const std::vector< array > &inputs, array &out) override
ReduceType
Definition primitives.h:1857
@ Prod
Definition primitives.h:1857
@ Min
Definition primitives.h:1857
@ Max
Definition primitives.h:1857
@ Sum
Definition primitives.h:1857
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
auto state() const
Definition primitives.h:1895
Scan(Stream stream, ReduceType reduce_type, int axis, bool reverse, bool inclusive)
Definition primitives.h:1859
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1877
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::pair< ReduceType, std::vector< int > > state() const
Definition primitives.h:1944
ReduceType
Definition primitives.h:1910
@ Sum
Definition primitives.h:1910
@ Max
Definition primitives.h:1910
@ Prod
Definition primitives.h:1910
@ None
Definition primitives.h:1910
@ Min
Definition primitives.h:1910
void eval_cpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1924
void eval_gpu(const std::vector< array > &inputs, array &out) override
Scatter(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1912
void eval_gpu(const std::vector< array > &inputs, array &out) override
Select(Stream stream)
Definition primitives.h:930
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sigmoid(Stream stream)
Definition primitives.h:1956
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Sign(Stream stream)
Definition primitives.h:1973
Sin(Stream stream)
Definition primitives.h:1990
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sinh(Stream stream)
Definition primitives.h:2007
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2041
Slice(Stream stream, const Shape &start_indices, const Shape &end_indices, const Shape &strides)
Definition primitives.h:2024
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
SliceUpdate(Stream stream, const Shape &start_indices, const Shape &end_indices, const Shape &strides)
Definition primitives.h:2055
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2073
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Softmax(Stream stream, bool precise)
Definition primitives.h:2132
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2144
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2166
Sort(Stream stream, int axis)
Definition primitives.h:2155
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::pair< std::vector< int >, int > state() const
Definition primitives.h:2190
Split(Stream stream, const Shape &indices, int axis)
Definition primitives.h:2178
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
auto state() const
Definition primitives.h:2230
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sqrt(Stream stream, bool recip=false)
Definition primitives.h:2220
void eval_gpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:2234
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Square(Stream stream)
Definition primitives.h:2203
Squeeze(Stream stream, std::vector< int > axes)
Definition primitives.h:2282
auto state() const
Definition primitives.h:2296
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const array &input, const std::vector< int > &axes)
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
StopGradient(Stream stream)
Definition primitives.h:2249
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Subtract(Stream stream)
Definition primitives.h:2265
Tan(Stream stream)
Definition primitives.h:2307
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Tanh(Stream stream)
Definition primitives.h:2324
void eval_cpu(const std::vector< array > &inputs, array &out) override
Transpose(Stream stream, const std::vector< int > &axes)
Definition primitives.h:2385
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< int > state() const
Definition primitives.h:2396
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:126
UnaryPrimitive & operator=(const UnaryPrimitive &other)=delete
UnaryPrimitive(Stream stream)
An abstract base class for a primitive with a single output.
Definition primitives.h:131
virtual void eval_gpu(const std::vector< array > &inputs, array &output)=0
UnaryPrimitive(UnaryPrimitive &&other)=delete
virtual void eval_cpu(const std::vector< array > &inputs, array &output)=0
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:141
UnaryPrimitive(const UnaryPrimitive &other)=delete
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output array.
Definition primitives.h:136
UnaryPrimitive & operator=(UnaryPrimitive &&other)=delete
virtual ~UnaryPrimitive()=default
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Unflatten(Stream stream, int axis, Shape shape)
Definition primitives.h:2341
static Shape output_shape(const array &input, int axis, const Shape &shape)
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2354
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2375
void print(std::ostream &os) override
Print the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
View(Stream stream, Dtype dtype)
Definition primitives.h:2366
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition array.h:24
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
array tri(int n, int m, int k, Dtype type, StreamOrDevice s={})
array transpose(const array &a, std::vector< int > axes, StreamOrDevice s={})
Permutes the dimensions according to the given axes.
array real(const array &a, StreamOrDevice s={})
Definition allocator.h:7
std::vector< ShapeElem > Shape
Definition array.h:21
Stream new_stream(Device d)
Make a new stream on the given device.
std::vector< int64_t > Strides
Definition array.h:22
void eval(std::vector< array > outputs)
#define DEFINE_DEFAULT_IS_EQUIVALENT()
Definition primitives.h:34
#define DEFINE_PRINT(PRIMITIVE)
Definition primitives.h:29
#define DEFINE_INPUT_OUTPUT_SHAPE()
Definition primitives.h:39
#define DEFINE_GRADS()
Definition primitives.h:17
#define DEFINE_VMAP()
Definition primitives.h:12
Definition device.h:7
static constexpr DeviceType gpu
Definition device.h:14
static constexpr DeviceType cpu
Definition device.h:13
Definition dtype.h:13
Definition stream.h:9
Device device
Definition stream.h:11