MLX
 
Loading...
Searching...
No Matches
primitives.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <unordered_set>
6
7#include "mlx/array.h"
8#include "mlx/device.h"
9#include "mlx/io/load.h"
10#include "mlx/stream.h"
11
12#define DEFINE_VMAP() \
13 virtual std::pair<std::vector<array>, std::vector<int>> vmap( \
14 const std::vector<array>& inputs, const std::vector<int>& axes) \
15 override;
16
17#define DEFINE_GRADS() \
18 std::vector<array> jvp( \
19 const std::vector<array>& primals, \
20 const std::vector<array>& tangents, \
21 const std::vector<int>& argnums) override; \
22 \
23 std::vector<array> vjp( \
24 const std::vector<array>& primals, \
25 const std::vector<array>& cotangents, \
26 const std::vector<int>& argnums, \
27 const std::vector<array>& outputs) override;
28
29#define DEFINE_PRINT(PRIMITIVE) \
30 void print(std::ostream& os) override { \
31 os << #PRIMITIVE; \
32 }
33
34#define DEFINE_DEFAULT_IS_EQUIVALENT() \
35 bool is_equivalent(const Primitive& other) const override { \
36 return true; \
37 }
38
39#define DEFINE_INPUT_OUTPUT_SHAPE() \
40 std::vector<Shape> output_shapes(const std::vector<array>& inputs) \
41 override { \
42 return {inputs[0].shape()}; \
43 }
44
45namespace mlx::core {
46
47// Abstract base class
48class Primitive {
49 public:
50 explicit Primitive(Stream stream) : stream_(stream) {}
51
53 const Device& device() {
54 return stream().device;
55 }
56
58 const Stream& stream() {
59 return stream_;
60 }
61
69 virtual void eval_cpu(
70 const std::vector<array>& inputs,
71 std::vector<array>& outputs) = 0;
72 virtual void eval_gpu(
73 const std::vector<array>& inputs,
74 std::vector<array>& outputs) = 0;
75
79 virtual std::vector<array> jvp(
80 const std::vector<array>& primals,
81 const std::vector<array>& tangents,
82 const std::vector<int>& argnums);
83
87 virtual std::vector<array> vjp(
88 const std::vector<array>& primals,
89 const std::vector<array>& cotangents,
90 const std::vector<int>& argnums,
91 const std::vector<array>& outputs);
92
99 virtual std::pair<std::vector<array>, std::vector<int>> vmap(
100 const std::vector<array>& inputs,
101 const std::vector<int>& axes);
102
104 virtual void print(std::ostream& os) = 0;
105
107 virtual bool is_equivalent(const Primitive& other) const {
108 return false;
109 }
110
113 virtual std::vector<Shape> output_shapes(const std::vector<array>& inputs);
114
115 virtual ~Primitive() = default;
116 Primitive(const Primitive& other) = delete;
117 Primitive(Primitive&& other) = delete;
118 Primitive& operator=(const Primitive& other) = delete;
119 Primitive& operator=(Primitive&& other) = delete;
120
121 private:
122 // Every primitive stores the stream it should run in
123 Stream stream_;
124};
125
126class UnaryPrimitive : public Primitive {
130 public:
132
133 virtual void eval_cpu(const std::vector<array>& inputs, array& output) = 0;
134 virtual void eval_gpu(const std::vector<array>& inputs, array& output) = 0;
135
136 inline void eval_cpu(
137 const std::vector<array>& inputs,
138 std::vector<array>& outputs) override {
139 eval_cpu(inputs, outputs[0]);
140 }
141 inline void eval_gpu(
142 const std::vector<array>& inputs,
143 std::vector<array>& outputs) override {
144 eval_gpu(inputs, outputs[0]);
145 }
146
147 virtual ~UnaryPrimitive() = default;
148 UnaryPrimitive(const UnaryPrimitive& other) = delete;
150 UnaryPrimitive& operator=(const UnaryPrimitive& other) = delete;
152};
153
154class Abs : public UnaryPrimitive {
155 public:
157
158 void eval_cpu(const std::vector<array>& inputs, array& out) override;
159 void eval_gpu(const std::vector<array>& inputs, array& out) override;
160
166};
167
168class Add : public UnaryPrimitive {
169 public:
171
172 void eval_cpu(const std::vector<array>& inputs, array& out) override;
173 void eval_gpu(const std::vector<array>& inputs, array& out) override;
174
180};
181
182class AddMM : public UnaryPrimitive {
183 public:
184 explicit AddMM(Stream stream, float alpha, float beta)
185 : UnaryPrimitive(stream), alpha_(alpha), beta_(beta) {}
186
187 void eval_cpu(const std::vector<array>& inputs, array& out) override;
188 void eval_gpu(const std::vector<array>& inputs, array& out) override;
189
193
194 bool is_equivalent(const Primitive& other) const override;
195 std::pair<float, float> state() const {
196 return {alpha_, beta_};
197 };
198
199 private:
200 const float alpha_;
201 const float beta_;
202};
203
204class Arange : public UnaryPrimitive {
205 public:
206 explicit Arange(Stream stream, double start, double stop, double step)
207 : UnaryPrimitive(stream), start_(start), stop_(stop), step_(step) {}
208
209 void eval_cpu(const std::vector<array>& inputs, array& out) override;
210 void eval_gpu(const std::vector<array>& inputs, array& out) override;
211
213 bool is_equivalent(const Primitive& other) const override;
214 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
215 std::tuple<double, double, double> state() const {
216 return {start_, stop_, step_};
217 };
218
219 private:
220 double start_;
221 double stop_;
222 double step_;
223};
224
225class ArcCos : public UnaryPrimitive {
226 public:
228
229 void eval_cpu(const std::vector<array>& inputs, array& out) override;
230 void eval_gpu(const std::vector<array>& inputs, array& out) override;
231
237};
238
239class ArcCosh : public UnaryPrimitive {
240 public:
242
243 void eval_cpu(const std::vector<array>& inputs, array& out) override;
244 void eval_gpu(const std::vector<array>& inputs, array& out) override;
245
251};
252
253class ArcSin : public UnaryPrimitive {
254 public:
256
257 void eval_cpu(const std::vector<array>& inputs, array& out) override;
258 void eval_gpu(const std::vector<array>& inputs, array& out) override;
259
265};
266
267class ArcSinh : public UnaryPrimitive {
268 public:
270
271 void eval_cpu(const std::vector<array>& inputs, array& out) override;
272 void eval_gpu(const std::vector<array>& inputs, array& out) override;
273
279};
280
281class ArcTan : public UnaryPrimitive {
282 public:
284
285 void eval_cpu(const std::vector<array>& inputs, array& out) override;
286 void eval_gpu(const std::vector<array>& inputs, array& out) override;
287
293};
294
295class ArcTan2 : public UnaryPrimitive {
296 public:
298
299 void eval_cpu(const std::vector<array>& inputs, array& out) override;
300 void eval_gpu(const std::vector<array>& inputs, array& out) override;
301
307};
308
309class ArcTanh : public UnaryPrimitive {
310 public:
312
313 void eval_cpu(const std::vector<array>& inputs, array& out) override;
314 void eval_gpu(const std::vector<array>& inputs, array& out) override;
315
321};
322
324 public:
325 explicit ArgPartition(Stream stream, int kth, int axis)
326 : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
327
328 void eval_cpu(const std::vector<array>& inputs, array& out) override;
329 void eval_gpu(const std::vector<array>& inputs, array& out) override;
330
335 bool is_equivalent(const Primitive& other) const override;
336 std::pair<int, int> state() const {
337 return {kth_, axis_};
338 };
339
340 private:
341 int kth_;
342 int axis_;
343};
344
345class ArgReduce : public UnaryPrimitive {
346 public:
351
352 explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)
353 : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {}
354
355 void eval_cpu(const std::vector<array>& inputs, array& out) override;
356 void eval_gpu(const std::vector<array>& inputs, array& out) override;
357
361 bool is_equivalent(const Primitive& other) const override;
362 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
363 std::pair<ReduceType, int> state() const {
364 return {reduce_type_, axis_};
365 };
366
367 private:
368 ReduceType reduce_type_;
369 int axis_;
370};
371
372class ArgSort : public UnaryPrimitive {
373 public:
374 explicit ArgSort(Stream stream, int axis)
375 : UnaryPrimitive(stream), axis_(axis) {}
376
377 void eval_cpu(const std::vector<array>& inputs, array& out) override;
378 void eval_gpu(const std::vector<array>& inputs, array& out) override;
379
383 bool is_equivalent(const Primitive& other) const override;
384 int state() const {
385 return axis_;
386 };
387
388 private:
389 int axis_;
390};
391
392class AsType : public UnaryPrimitive {
393 public:
394 explicit AsType(Stream stream, Dtype dtype)
395 : UnaryPrimitive(stream), dtype_(dtype) {}
396
397 void eval_cpu(const std::vector<array>& inputs, array& out) override;
398 void eval_gpu(const std::vector<array>& inputs, array& out) override;
399
404 bool is_equivalent(const Primitive& other) const override;
405 Dtype state() const {
406 return dtype_;
407 };
408
409 private:
410 Dtype dtype_;
411};
412
413class AsStrided : public UnaryPrimitive {
414 public:
415 explicit AsStrided(Stream stream, Shape shape, Strides strides, size_t offset)
417 shape_(std::move(shape)),
418 strides_(std::move(strides)),
419 offset_(offset) {}
420
421 void eval_cpu(const std::vector<array>& inputs, array& out) override;
422 void eval_gpu(const std::vector<array>& inputs, array& out) override;
423
426 bool is_equivalent(const Primitive& other) const override;
427 auto state() const {
428 return std::make_tuple(shape_, strides_, offset_);
429 }
430
431 private:
432 Shape shape_;
433 Strides strides_;
434 size_t offset_;
435
436 void eval(const std::vector<array>& inputs, array& out);
437};
438
440 public:
442
444 : UnaryPrimitive(stream), op_(op) {}
445
446 void eval_cpu(const std::vector<array>& inputs, array& out) override;
447 void eval_gpu(const std::vector<array>& inputs, array& out) override;
448
451 bool is_equivalent(const Primitive& other) const override;
452 void print(std::ostream& os) override;
454 auto state() const {
455 return op_;
456 }
457
458 private:
459 Op op_;
460};
461
463 public:
465
466 void eval_cpu(const std::vector<array>& inputs, array& out) override;
467 void eval_gpu(const std::vector<array>& inputs, array& out) override;
468
473};
474
476 public:
477 explicit BlockMaskedMM(Stream stream, int block_size)
478 : UnaryPrimitive(stream), block_size_(block_size) {}
479
480 void eval_cpu(const std::vector<array>& inputs, array& out) override;
481 void eval_gpu(const std::vector<array>& inputs, array& out) override;
482
483 std::vector<array> vjp(
484 const std::vector<array>& primals,
485 const std::vector<array>& cotangents,
486 const std::vector<int>& argnums,
487 const std::vector<array>& outputs) override;
488
490 bool is_equivalent(const Primitive& other) const override;
491 auto state() const {
492 return block_size_;
493 }
494
495 private:
496 int block_size_;
497};
498
499class GatherMM : public UnaryPrimitive {
500 public:
502
503 void eval_cpu(const std::vector<array>& inputs, array& out) override;
504 void eval_gpu(const std::vector<array>& inputs, array& out) override;
505
506 std::vector<array> vjp(
507 const std::vector<array>& primals,
508 const std::vector<array>& cotangents,
509 const std::vector<int>& argnums,
510 const std::vector<array>& outputs) override;
511
514};
515
517 public:
518 explicit BroadcastAxes(Stream stream, std::vector<int> ignore_axes = {})
519 : UnaryPrimitive(stream), ignore_axes_(std::move(ignore_axes)) {}
520
521 void eval_cpu(const std::vector<array>& inputs, array& out) override;
522 void eval_gpu(const std::vector<array>& inputs, array& out) override;
523
527 bool is_equivalent(const Primitive& other) const override;
529 const std::vector<array>& inputs,
530 const std::vector<int>& ignore_axes);
531 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
532 auto state() const {
533 return ignore_axes_;
534 }
535
536 private:
537 void eval(const std::vector<array>& inputs, array& out);
538 std::vector<int> ignore_axes_;
539};
540
541class Broadcast : public UnaryPrimitive {
542 public:
543 explicit Broadcast(Stream stream, const Shape& shape)
544 : UnaryPrimitive(stream), shape_(shape) {}
545
546 void eval_cpu(const std::vector<array>& inputs, array& out) override;
547 void eval_gpu(const std::vector<array>& inputs, array& out) override;
548
552 static Shape output_shape(const std::vector<array>& inputs);
553 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
554 bool is_equivalent(const Primitive& other) const override;
555 std::vector<int> state() const {
556 return shape_;
557 };
558
559 private:
560 Shape shape_;
561
562 void eval(const std::vector<array>& inputs, array& out);
563};
564
565class Ceil : public UnaryPrimitive {
566 public:
568
569 void eval_cpu(const std::vector<array>& inputs, array& out) override;
570 void eval_gpu(const std::vector<array>& inputs, array& out) override;
571
577};
578
579class Compiled : public Primitive {
580 public:
581 /*
582 * The inputs, outputs and tape are either tracers or constants.
583 * - The tape should not contain the inputs, but it should contain the
584 * outputs.
585 * - The tape should also have only one array per primitive for multi-output
586 * primitives.
587 * - The constant_ids contains ids of arrays in the input list that are safe
588 * to treat as scalar constants.
589 */
590 explicit Compiled(
592 std::vector<array> inputs,
593 std::vector<array> outputs,
594 std::vector<array> tape,
595 std::unordered_set<uintptr_t> constant_ids);
596
597 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
598 override;
599 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
600 override;
601
604 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
605 void print(std::ostream& os) override;
606 bool is_equivalent(const Primitive& other) const override;
607
608 std::string lib_name() const {
609 return kernel_lib_;
610 }
611
612 private:
613 const std::vector<array> inputs_;
614 const std::vector<array> outputs_;
615 const std::vector<array> tape_;
616 const std::unordered_set<uintptr_t> constant_ids_;
617
618 std::string kernel_lib_;
619};
620
622 public:
623 explicit Concatenate(Stream stream, int axis)
624 : UnaryPrimitive(stream), axis_(axis) {}
625
626 void eval_cpu(const std::vector<array>& inputs, array& out) override;
627 void eval_gpu(const std::vector<array>& inputs, array& out) override;
628
632 bool is_equivalent(const Primitive& other) const override;
633 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
634 auto state() const {
635 return axis_;
636 }
637
638 private:
639 int axis_;
640};
641
642class Conjugate : public UnaryPrimitive {
643 public:
645
646 void eval_cpu(const std::vector<array>& inputs, array& out) override;
647 void eval_gpu(const std::vector<array>& inputs, array& out) override;
648
653};
654
656 public:
657 explicit Contiguous(Stream stream, bool allow_col_major)
658 : UnaryPrimitive(stream), allow_col_major_(allow_col_major) {}
659
660 void eval_cpu(const std::vector<array>& inputs, array& out) override;
661 void eval_gpu(const std::vector<array>& inputs, array& out) override;
662
667
668 bool is_equivalent(const Primitive& other) const override;
669
670 private:
671 bool allow_col_major_;
672};
673
675 public:
676 explicit Convolution(
678 const std::vector<int>& kernel_strides,
679 const std::vector<int>& padding,
680 const std::vector<int>& kernel_dilation,
681 const std::vector<int>& input_dilation,
682 const int groups = 1,
683 const bool flip = false)
685 padding_(padding),
686 kernel_strides_(kernel_strides),
687 kernel_dilation_(kernel_dilation),
688 input_dilation_(input_dilation),
689 groups_(groups),
690 flip_(flip) {}
691
692 void eval_cpu(const std::vector<array>& inputs, array& out) override;
693 void eval_gpu(const std::vector<array>& inputs, array& out) override;
694
695 std::vector<array> vjp(
696 const std::vector<array>& primals,
697 const std::vector<array>& cotangents,
698 const std::vector<int>& argnums,
699 const std::vector<array>& outputs) override;
700
702 bool is_equivalent(const Primitive& other) const override;
703 auto state() const {
704 return std::make_tuple(
705 padding_,
706 kernel_strides_,
707 kernel_dilation_,
708 input_dilation_,
709 groups_,
710 flip_);
711 }
712
713 private:
714 std::vector<int> padding_;
715 std::vector<int> kernel_strides_;
716 std::vector<int> kernel_dilation_;
717 std::vector<int> input_dilation_;
718 int groups_;
719 bool flip_;
720};
721
722class Copy : public UnaryPrimitive {
723 public:
725
726 void eval_cpu(const std::vector<array>& inputs, array& out) override;
727 void eval_gpu(const std::vector<array>& inputs, array& out) override;
728
734
735 private:
736 void eval(const std::vector<array>& inputs, array& out);
737};
738
739class Cos : public UnaryPrimitive {
740 public:
742
743 void eval_cpu(const std::vector<array>& inputs, array& out) override;
744 void eval_gpu(const std::vector<array>& inputs, array& out) override;
745
751};
752
753class Cosh : public UnaryPrimitive {
754 public:
756
757 void eval_cpu(const std::vector<array>& inputs, array& out) override;
758 void eval_gpu(const std::vector<array>& inputs, array& out) override;
759
765};
766
768 public:
771 int num_outputs,
772 std::function<std::vector<array>(
773 const std::vector<array>&,
774 const std::vector<array>&,
775 const std::vector<array>&)> vjp,
776 std::function<std::vector<array>(
777 const std::vector<array>&,
778 const std::vector<array>&,
779 const std::vector<int>&)> jvp,
780 std::function<std::pair<std::vector<array>, std::vector<int>>(
781 const std::vector<array>&,
782 const std::vector<int>&)> vmap)
783 : Primitive(stream),
784 num_outputs_(num_outputs),
785 vjp_fun_(std::move(vjp)),
786 jvp_fun_(std::move(jvp)),
787 vmap_fun_(std::move(vmap)) {}
788
789 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
790 override;
791 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
792 override;
793
797
798 private:
799 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
800
801 int num_outputs_;
802
803 std::function<std::vector<array>(
804 const std::vector<array>&,
805 const std::vector<array>&,
806 const std::vector<array>&)>
807 vjp_fun_;
808 std::function<std::vector<array>(
809 const std::vector<array>&,
810 const std::vector<array>&,
811 const std::vector<int>&)>
812 jvp_fun_;
813 std::function<std::pair<std::vector<array>, std::vector<int>>(
814 const std::vector<array>&,
815 const std::vector<int>&)>
816 vmap_fun_;
817};
818
819class Depends : public Primitive {
820 public:
822
823 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
824 override;
825 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
826 override;
827
828 std::vector<array> vjp(
829 const std::vector<array>& primals,
830 const std::vector<array>& cotan,
831 const std::vector<int>& argnums,
832 const std::vector<array>& outputs) override;
833
835
836 private:
837 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
838};
839
840class Divide : public UnaryPrimitive {
841 public:
843
844 void eval_cpu(const std::vector<array>& inputs, array& out) override;
845 void eval_gpu(const std::vector<array>& inputs, array& out) override;
846
852};
853
854class DivMod : public Primitive {
855 public:
857
858 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
859 override;
860 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
861 override;
862
867 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
868 return std::vector{inputs[0].shape(), inputs[0].shape()};
869 }
870};
871
872class Select : public UnaryPrimitive {
873 public:
875
876 void eval_cpu(const std::vector<array>& inputs, array& out) override;
877 void eval_gpu(const std::vector<array>& inputs, array& out) override;
878
884};
885
886class Remainder : public UnaryPrimitive {
887 public:
889
890 void eval_cpu(const std::vector<array>& inputs, array& out) override;
891 void eval_gpu(const std::vector<array>& inputs, array& out) override;
892
898};
899
900class Equal : public UnaryPrimitive {
901 public:
902 explicit Equal(Stream stream, bool equal_nan = false)
903 : UnaryPrimitive(stream), equal_nan_(equal_nan) {}
904
905 void eval_cpu(const std::vector<array>& inputs, array& out) override;
906 void eval_gpu(const std::vector<array>& inputs, array& out) override;
907
912
913 void print(std::ostream& os) override {
914 if (equal_nan_) {
915 os << "NaNEqual";
916 } else {
917 os << "Equal";
918 }
919 }
920 auto state() const {
921 return equal_nan_;
922 };
923
924 private:
925 bool equal_nan_;
926};
927
928class Erf : public UnaryPrimitive {
929 public:
931
932 void eval_cpu(const std::vector<array>& inputs, array& out) override;
933 void eval_gpu(const std::vector<array>& inputs, array& out) override;
934
940};
941
942class ErfInv : public UnaryPrimitive {
943 public:
945
946 void eval_cpu(const std::vector<array>& inputs, array& out) override;
947 void eval_gpu(const std::vector<array>& inputs, array& out) override;
948
954};
955
956class Exp : public UnaryPrimitive {
957 public:
959
960 void eval_cpu(const std::vector<array>& inputs, array& out) override;
961 void eval_gpu(const std::vector<array>& inputs, array& out) override;
962
968};
969
970class Expm1 : public UnaryPrimitive {
971 public:
973
974 void eval_cpu(const std::vector<array>& inputs, array& out) override;
975 void eval_gpu(const std::vector<array>& inputs, array& out) override;
976
981};
982
984 public:
985 explicit ExpandDims(Stream stream, std::vector<int> axes)
986 : UnaryPrimitive(stream), axes_(std::move(axes)) {}
987
988 void eval_cpu(const std::vector<array>& inputs, array& out) override;
989 void eval_gpu(const std::vector<array>& inputs, array& out) override;
990
994
995 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
996 bool is_equivalent(const Primitive& other) const override;
997
998 static Shape output_shape(const array& input, const std::vector<int>& axes);
999 auto state() const {
1000 return axes_;
1001 }
1002
1003 private:
1004 void eval(const std::vector<array>& inputs, array& out);
1005 std::vector<int> axes_;
1006};
1007
1008class FFT : public UnaryPrimitive {
1009 public:
1010 explicit FFT(
1011 Stream stream,
1012 const std::vector<size_t>& axes,
1013 bool inverse,
1014 bool real)
1015 : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {}
1016
1017 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1018 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1019
1023
1024 bool is_equivalent(const Primitive& other) const override;
1025 auto state() const {
1026 return std::make_tuple(axes_, inverse_, real_);
1027 }
1028
1029 private:
1030 std::vector<size_t> axes_;
1031 bool inverse_;
1032 bool real_;
1033};
1034
1035class Flatten : public UnaryPrimitive {
1036 public:
1037 explicit Flatten(Stream stream, int start_axis, int end_axis)
1038 : UnaryPrimitive(stream), start_axis_(start_axis), end_axis_(end_axis) {}
1039
1040 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1041 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1042
1046 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1047 bool is_equivalent(const Primitive& other) const override;
1048
1049 static Shape output_shape(const array& input, int start_axis, int end_axis);
1050 auto state() const {
1051 return std::make_pair(start_axis_, end_axis_);
1052 }
1053
1054 private:
1055 int start_axis_;
1056 int end_axis_;
1057 void eval(const std::vector<array>& inputs, array& out);
1058};
1059
1060class Floor : public UnaryPrimitive {
1061 public:
1063
1064 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1065 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1066
1072};
1073
1074class Full : public UnaryPrimitive {
1075 public:
1077
1078 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1079 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1080
1085};
1086
1087class Gather : public UnaryPrimitive {
1088 public:
1089 explicit Gather(Stream stream, std::vector<int> axes, Shape slice_sizes)
1091 axes_(std::move(axes)),
1092 slice_sizes_(std::move(slice_sizes)) {}
1093
1094 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1095 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1096
1100 bool is_equivalent(const Primitive& other) const override;
1101 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1102 std::pair<std::vector<int>, std::vector<int>> state() const {
1103 return {axes_, slice_sizes_};
1104 }
1105
1106 private:
1107 std::vector<int> axes_;
1108 Shape slice_sizes_;
1109};
1110
1112 public:
1113 explicit GatherAxis(Stream stream, int axis)
1114 : UnaryPrimitive(stream), axis_(axis) {}
1115
1116 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1117 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1118
1122 bool is_equivalent(const Primitive& other) const override;
1123 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1124 auto state() const {
1125 return axis_;
1126 }
1127
1128 private:
1129 int axis_;
1130};
1131
1132class Greater : public UnaryPrimitive {
1133 public:
1135
1136 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1137 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1138
1144};
1145
1147 public:
1149
1150 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1151 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1152
1158};
1159
1160class Hadamard : public UnaryPrimitive {
1161 public:
1162 explicit Hadamard(Stream stream, float scale)
1163 : UnaryPrimitive(stream), scale_(scale) {}
1164
1165 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1166 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1167
1172
1173 bool is_equivalent(const Primitive& other) const override;
1174 auto state() const {
1175 return scale_;
1176 }
1177
1178 private:
1179 float scale_;
1180};
1181
1182class Imag : public UnaryPrimitive {
1183 public:
1185
1186 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1187 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1188
1194};
1195
1196class Less : public UnaryPrimitive {
1197 public:
1199
1200 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1201 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1202
1208};
1209
1211 public:
1213
1214 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1215 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1216
1222};
1223
1224class Load : public UnaryPrimitive {
1225 public:
1226 explicit Load(
1227 Stream stream,
1228 std::shared_ptr<io::Reader> reader,
1229 size_t offset,
1230 bool swap_endianness = false)
1232 reader_(std::move(reader)),
1233 offset_(offset),
1234 swap_endianness_(swap_endianness) {}
1235
1236 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1237 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1238
1240
1241 private:
1242 std::shared_ptr<io::Reader> reader_;
1243 size_t offset_;
1244 bool swap_endianness_;
1245};
1246
1247class Log : public UnaryPrimitive {
1248 public:
1249 enum Base { two, ten, e };
1250
1251 explicit Log(Stream stream, Base base)
1252 : UnaryPrimitive(stream), base_(base) {}
1253
1254 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1255 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1256
1261
1262 Base state() const {
1263 return base_;
1264 };
1265
1266 void print(std::ostream& os) override {
1267 switch (base_) {
1268 case e:
1269 os << "Log";
1270 break;
1271 case two:
1272 os << "Log2";
1273 break;
1274 case ten:
1275 os << "Log10";
1276 break;
1277 }
1278 }
1279
1280 private:
1281 Base base_;
1282};
1283
1284class Log1p : public UnaryPrimitive {
1285 public:
1287
1288 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1289 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1290
1295};
1296
1298 public:
1300
1301 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1302 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1303
1309};
1310
1312 public:
1314
1315 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1316 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1317
1323};
1324
1326 public:
1328
1329 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1330 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1331
1337};
1338
1340 public:
1342
1343 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1344 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1345
1351};
1352
1353class Matmul : public UnaryPrimitive {
1354 public:
1356
1357 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1358 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1359
1364 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1365};
1366
1367class Maximum : public UnaryPrimitive {
1368 public:
1370
1371 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1372 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1373
1379};
1380
1381class Minimum : public UnaryPrimitive {
1382 public:
1384
1385 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1386 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1387
1393};
1394
1395class Multiply : public UnaryPrimitive {
1396 public:
1398
1399 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1400 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1401
1407};
1408
1409class Negative : public UnaryPrimitive {
1410 public:
1412
1413 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1414 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1415
1421};
1422
1423class NotEqual : public UnaryPrimitive {
1424 public:
1426
1427 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1428 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1429
1435};
1436
1438 public:
1440 Stream stream,
1441 std::vector<int> axes,
1442 bool inverted,
1443 Dtype dtype)
1445 axes_(std::move(axes)),
1446 inverted_(inverted),
1447 dtype_(dtype) {}
1448
1449 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1450 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1451
1454 bool is_equivalent(const Primitive& other) const override;
1455 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
1456 return {{}};
1457 }
1458 std::tuple<std::vector<int>, bool, Dtype> state() const {
1459 return {axes_, inverted_, dtype_};
1460 }
1461
1462 private:
1463 std::vector<int> axes_;
1464 bool inverted_;
1465 Dtype dtype_;
1466
1467 void eval(const std::vector<array>& inputs, array& out);
1468};
1469
1470class Pad : public UnaryPrimitive {
1471 public:
1472 explicit Pad(
1473 Stream stream,
1474 const std::vector<int>& axes,
1475 const Shape& low_pad_size,
1476 const Shape& high_pad_size)
1478 axes_(axes),
1479 low_pad_size_(low_pad_size),
1480 high_pad_size_(high_pad_size) {}
1481
1482 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1483 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1484
1488 bool is_equivalent(const Primitive& other) const override;
1489 auto state() const {
1490 return std::make_tuple(axes_, low_pad_size_, high_pad_size_);
1491 }
1492
1493 private:
1494 std::vector<int> axes_;
1495 Shape low_pad_size_;
1496 Shape high_pad_size_;
1497};
1498
1500 public:
1501 explicit Partition(Stream stream, int kth, int axis)
1502 : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
1503
1504 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1505 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1506
1511 bool is_equivalent(const Primitive& other) const override;
1512 auto state() const {
1513 return std::make_pair(kth_, axis_);
1514 };
1515
1516 private:
1517 int kth_;
1518 int axis_;
1519};
1520
1521class Power : public UnaryPrimitive {
1522 public:
1524
1525 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1526 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1527
1533};
1534
1536 public:
1538 Stream stream,
1539 int group_size,
1540 int bits,
1541 bool transpose)
1543 group_size_(group_size),
1544 bits_(bits),
1545 transpose_(transpose) {}
1546
1547 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1548 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1549
1553 bool is_equivalent(const Primitive& other) const override;
1554 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1555 auto state() const {
1556 return std::make_tuple(group_size_, bits_, transpose_);
1557 }
1558
1559 private:
1560 int group_size_;
1561 int bits_;
1562 bool transpose_;
1563};
1564
1566 public:
1567 explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
1569 group_size_(group_size),
1570 bits_(bits),
1571 transpose_(transpose) {}
1572
1573 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1574 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1575
1579 bool is_equivalent(const Primitive& other) const override;
1580 auto state() const {
1581 return std::make_tuple(group_size_, bits_, transpose_);
1582 }
1583
1584 private:
1585 int group_size_;
1586 int bits_;
1587 bool transpose_;
1588};
1589
1591 public:
1592 explicit RandomBits(Stream stream, const Shape& shape, int width)
1593 : UnaryPrimitive(stream), shape_(shape), width_(width) {}
1594
1595 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1596 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1597
1600 bool is_equivalent(const Primitive& other) const override;
1601 std::pair<std::vector<int>, int> state() const {
1602 return {shape_, width_};
1603 };
1604
1605 private:
1606 Shape shape_;
1607 int width_;
1608};
1609
1610class Real : public UnaryPrimitive {
1611 public:
1613
1614 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1615 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1616
1622};
1623
1624class Reshape : public UnaryPrimitive {
1625 public:
1626 explicit Reshape(Stream stream, const Shape& shape)
1627 : UnaryPrimitive(stream), shape_(shape) {}
1628
1629 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1630 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1631
1635 bool is_equivalent(const Primitive& other) const override;
1636 std::vector<int> state() const {
1637 return shape_;
1638 };
1639 static Shape output_shape(const array& input, Shape shape);
1640 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1641
1642 private:
1643 Shape shape_;
1644};
1645
1646class Reduce : public UnaryPrimitive {
1647 public:
1648 enum ReduceType { And, Or, Sum, Prod, Min, Max };
1649
1650 explicit Reduce(
1651 Stream stream,
1652 ReduceType reduce_type,
1653 const std::vector<int>& axes)
1654 : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1655
1656 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1657 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1658
1660
1661 std::vector<array> vjp(
1662 const std::vector<array>& primals,
1663 const std::vector<array>& cotangents,
1664 const std::vector<int>& argnums,
1665 const std::vector<array>& outputs) override;
1666
1667 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1668
1669 void print(std::ostream& os) override {
1670 switch (reduce_type_) {
1671 case And:
1672 os << "And";
1673 break;
1674 case Or:
1675 os << "Or";
1676 break;
1677 case Sum:
1678 os << "Sum";
1679 break;
1680 case Prod:
1681 os << "Prod";
1682 break;
1683 case Min:
1684 os << "Min";
1685 break;
1686 case Max:
1687 os << "Max";
1688 break;
1689 }
1690 }
1691 bool is_equivalent(const Primitive& other) const override;
1692 std::pair<ReduceType, std::vector<int>> state() const {
1693 return {reduce_type_, axes_};
1694 };
1695
1696 private:
1697 ReduceType reduce_type_;
1698 std::vector<int> axes_;
1699};
1700
1701class Round : public UnaryPrimitive {
1702 public:
1704
1705 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1706 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1707
1713};
1714
1715class Scan : public UnaryPrimitive {
1716 public:
1718
1719 explicit Scan(
1720 Stream stream,
1721 ReduceType reduce_type,
1722 int axis,
1723 bool reverse,
1724 bool inclusive)
1726 reduce_type_(reduce_type),
1727 axis_(axis),
1728 reverse_(reverse),
1729 inclusive_(inclusive) {}
1730
1731 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1732 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1733
1736
1737 void print(std::ostream& os) override {
1738 os << "Cum";
1739 switch (reduce_type_) {
1740 case Sum:
1741 os << "Sum";
1742 break;
1743 case Prod:
1744 os << "Prod";
1745 break;
1746 case Min:
1747 os << "Min";
1748 break;
1749 case Max:
1750 os << "Max";
1751 break;
1752 }
1753 }
1754 bool is_equivalent(const Primitive& other) const override;
1755 auto state() const {
1756 return std::make_tuple(reduce_type_, axis_, reverse_, inclusive_);
1757 }
1758
1759 private:
1760 ReduceType reduce_type_;
1761 int axis_;
1762 bool reverse_;
1763 bool inclusive_;
1764};
1765
1766class Scatter : public UnaryPrimitive {
1767 public:
1769
1770 explicit Scatter(
1771 Stream stream,
1772 ReduceType reduce_type,
1773 const std::vector<int>& axes)
1774 : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1775
1776 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1777 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1778
1781
1782 void print(std::ostream& os) override {
1783 os << "Scatter";
1784 switch (reduce_type_) {
1785 case Sum:
1786 os << " Sum";
1787 break;
1788 case Prod:
1789 os << " Prod";
1790 break;
1791 case Min:
1792 os << " Min";
1793 break;
1794 case Max:
1795 os << " Max";
1796 break;
1797 case None:
1798 break;
1799 }
1800 }
1801 bool is_equivalent(const Primitive& other) const override;
1802 std::pair<ReduceType, std::vector<int>> state() const {
1803 return {reduce_type_, axes_};
1804 };
1805
1806 private:
1807 ReduceType reduce_type_;
1808 std::vector<int> axes_;
1809};
1810
1812 public:
1814
1815 explicit ScatterAxis(Stream stream, ReduceType reduce_type, int axis)
1816 : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {}
1817
1818 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1819 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1820
1823
1824 void print(std::ostream& os) override {
1825 os << "ScatterAxis";
1826 switch (reduce_type_) {
1827 case Sum:
1828 os << " Sum";
1829 break;
1830 case None:
1831 break;
1832 }
1833 }
1834
1835 bool is_equivalent(const Primitive& other) const override;
1836 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1837 std::pair<ReduceType, int> state() const {
1838 return {reduce_type_, axis_};
1839 }
1840
1841 private:
1842 ReduceType reduce_type_;
1843 int axis_;
1844};
1845
1846class Sigmoid : public UnaryPrimitive {
1847 public:
1849
1850 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1851 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1852
1858};
1859
1860class Sign : public UnaryPrimitive {
1861 public:
1863
1864 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1865 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1866
1872};
1873
1874class Sin : public UnaryPrimitive {
1875 public:
1877
1878 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1879 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1880
1886};
1887
1888class Sinh : public UnaryPrimitive {
1889 public:
1891
1892 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1893 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1894
1900};
1901
1902class Slice : public UnaryPrimitive {
1903 public:
1904 explicit Slice(
1905 Stream stream,
1906 const Shape& start_indices,
1907 const Shape& end_indices,
1908 const Shape& strides)
1910 start_indices_(start_indices),
1911 end_indices_(end_indices),
1912 strides_(strides) {}
1913
1914 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1915 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1916
1920 bool is_equivalent(const Primitive& other) const override;
1921 auto state() const {
1922 return std::make_tuple(start_indices_, end_indices_, strides_);
1923 }
1924
1925 private:
1926 Shape start_indices_;
1927 Shape end_indices_;
1928 Shape strides_;
1929};
1930
1932 public:
1933 explicit SliceUpdate(
1934 Stream stream,
1935 const Shape& start_indices,
1936 const Shape& end_indices,
1937 const Shape& strides)
1939 start_indices_(start_indices),
1940 end_indices_(end_indices),
1941 strides_(strides) {}
1942
1943 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1944 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1945
1949 bool is_equivalent(const Primitive& other) const override;
1951 auto state() const {
1952 return std::make_tuple(start_indices_, end_indices_, strides_);
1953 }
1954
1955 private:
1956 Shape start_indices_;
1957 Shape end_indices_;
1958 Shape strides_;
1959};
1960
1962 public:
1963 explicit DynamicSlice(Stream stream, std::vector<int> axes, Shape slice_size)
1965 axes_(std::move(axes)),
1966 slice_size_(std::move(slice_size)) {}
1967
1968 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1969 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1970
1974 bool is_equivalent(const Primitive& other) const override;
1975 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1976 auto state() const {
1977 return std::make_pair(axes_, slice_size_);
1978 }
1979
1980 private:
1981 std::vector<int> axes_;
1982 Shape slice_size_;
1983};
1984
1986 public:
1987 explicit DynamicSliceUpdate(Stream stream, std::vector<int> axes)
1988 : UnaryPrimitive(stream), axes_(std::move(axes)) {}
1989
1990 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1991 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1992
1996 bool is_equivalent(const Primitive& other) const override;
1998 auto state() const {
1999 return axes_;
2000 }
2001
2002 private:
2003 std::vector<int> axes_;
2004};
2005
2006class Softmax : public UnaryPrimitive {
2007 public:
2008 explicit Softmax(Stream stream, bool precise)
2009 : UnaryPrimitive(stream), precise_(precise) {}
2010
2011 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2012 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2013
2018
2019 bool is_equivalent(const Primitive& other) const override;
2020 auto state() const {
2021 return precise_;
2022 };
2023
2024 private:
2025 bool precise_;
2026};
2027
2028class Sort : public UnaryPrimitive {
2029 public:
2030 explicit Sort(Stream stream, int axis)
2031 : UnaryPrimitive(stream), axis_(axis) {}
2032
2033 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2034 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2035
2040 bool is_equivalent(const Primitive& other) const override;
2041 auto state() const {
2042 return axis_;
2043 }
2044
2045 private:
2046 int axis_;
2047};
2048
2049class Split : public Primitive {
2050 public:
2051 explicit Split(Stream stream, const Shape& indices, int axis)
2052 : Primitive(stream), indices_(indices), axis_(axis) {}
2053
2054 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2055 override;
2056 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2057 override;
2058
2062 bool is_equivalent(const Primitive& other) const override;
2063 std::pair<std::vector<int>, int> state() const {
2064 return {indices_, axis_};
2065 };
2066
2067 private:
2068 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
2069
2070 Shape indices_;
2071 int axis_;
2072};
2073
2074class Square : public UnaryPrimitive {
2075 public:
2077
2078 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2079 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2080
2086};
2087
2088class Sqrt : public UnaryPrimitive {
2089 public:
2090 explicit Sqrt(Stream stream, bool recip = false)
2091 : UnaryPrimitive(stream), recip_(recip) {}
2092
2093 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2094 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2095
2099 bool is_equivalent(const Primitive& other) const override;
2100 auto state() const {
2101 return recip_;
2102 }
2103
2104 void print(std::ostream& os) override {
2105 if (recip_) {
2106 os << "Rsqrt";
2107 } else {
2108 os << "Sqrt";
2109 }
2110 }
2111
2112 private:
2113 bool recip_;
2114};
2115
2117 public:
2119
2120 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2121 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2122
2127
2128 private:
2129 void eval(const std::vector<array>& inputs, array& out);
2130};
2131
2132class Subtract : public UnaryPrimitive {
2133 public:
2135
2136 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2137 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2138
2144};
2145
2146class Squeeze : public UnaryPrimitive {
2147 public:
2148 explicit Squeeze(Stream stream, std::vector<int> axes)
2149 : UnaryPrimitive(stream), axes_(std::move(axes)) {}
2150
2151 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2152 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2153
2157
2158 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
2159 bool is_equivalent(const Primitive& other) const override;
2160
2161 static Shape output_shape(const array& input, const std::vector<int>& axes);
2162 auto state() const {
2163 return axes_;
2164 };
2165
2166 private:
2167 void eval(const std::vector<array>& inputs, array& out);
2168 std::vector<int> axes_;
2169};
2170
2171class Tan : public UnaryPrimitive {
2172 public:
2174
2175 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2176 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2177
2183};
2184
2185class Tanh : public UnaryPrimitive {
2186 public:
2188
2189 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2190 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2191
2197};
2198
2200 public:
2201 explicit Unflatten(Stream stream, int axis, Shape shape)
2202 : UnaryPrimitive(stream), axis_(axis), shape_(std::move(shape)) {}
2203
2204 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2205 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2206
2210 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
2211 bool is_equivalent(const Primitive& other) const override;
2212
2213 static Shape output_shape(const array& input, int axis, const Shape& shape);
2214 auto state() const {
2215 return std::make_pair(axis_, shape_);
2216 }
2217
2218 private:
2219 int axis_;
2220 Shape shape_;
2221 void eval(const std::vector<array>& inputs, array& out);
2222};
2223
2224class View : public UnaryPrimitive {
2225 public:
2226 explicit View(Stream stream, Dtype dtype)
2227 : UnaryPrimitive(stream), dtype_(dtype) {}
2228
2229 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2230 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2231
2233 void print(std::ostream& os) override;
2234 bool is_equivalent(const Primitive& other) const override;
2235 auto state() const {
2236 return dtype_;
2237 }
2238
2239 private:
2240 Dtype dtype_;
2241};
2242
2244 public:
2245 explicit Transpose(Stream stream, const std::vector<int>& axes)
2246 : UnaryPrimitive(stream), axes_(axes) {}
2247
2248 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2249 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2250
2254 bool is_equivalent(const Primitive& other) const override;
2255 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
2256 std::vector<int> state() const {
2257 return axes_;
2258 };
2259
2260 private:
2261 std::vector<int> axes_;
2262
2263 void eval(const std::vector<array>& inputs, array& out);
2264};
2265
2266/* QR Factorization primitive. */
2267class QRF : public Primitive {
2268 public:
2270
2271 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2272 override;
2273 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2274 override;
2275
2277};
2278
2279/* SVD primitive. */
2280class SVD : public Primitive {
2281 public:
2282 explicit SVD(Stream stream, bool compute_uv)
2283 : Primitive(stream), compute_uv_(compute_uv) {}
2284
2285 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2286 override;
2287 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2288 override;
2289
2292 auto state() const {
2293 return compute_uv_;
2294 }
2295
2296 private:
2297 bool compute_uv_;
2298};
2299
2300/* Matrix inversion primitive. */
2301class Inverse : public UnaryPrimitive {
2302 public:
2303 explicit Inverse(Stream stream, bool tri, bool upper)
2304 : UnaryPrimitive(stream), tri_(tri), upper_(upper) {}
2305
2306 void eval_cpu(const std::vector<array>& inputs, array& output) override;
2307 void eval_gpu(const std::vector<array>& inputs, array& output) override;
2308
2311 auto state() const {
2312 return std::make_pair(tri_, upper_);
2313 }
2314
2315 private:
2316 bool tri_;
2317 bool upper_;
2318};
2319
2320class Cholesky : public UnaryPrimitive {
2321 public:
2322 explicit Cholesky(Stream stream, bool upper)
2323 : UnaryPrimitive(stream), upper_(upper) {}
2324
2325 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2326 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2327 auto state() const {
2328 return upper_;
2329 }
2330
2333
2334 private:
2335 bool upper_;
2336};
2337
2338class Eigh : public Primitive {
2339 public:
2340 explicit Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
2341 : Primitive(stream),
2342 uplo_(std::move(uplo)),
2343 compute_eigenvectors_(compute_eigenvectors) {}
2344 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2345 override;
2346 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2347 override;
2348
2351
2352 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
2353
2354 bool is_equivalent(const Primitive& other) const override;
2355 auto state() const {
2356 return std::make_pair(uplo_, compute_eigenvectors_);
2357 }
2358
2359 private:
2360 std::string uplo_;
2361 bool compute_eigenvectors_;
2362};
2363
2364/* LU Factorization primitive. */
2365class LUF : public Primitive {
2366 public:
2368 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2369 override;
2370 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2371 override;
2372
2374};
2375
2376} // namespace mlx::core
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Abs(Stream stream)
Definition primitives.h:156
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Add(Stream stream)
Definition primitives.h:170
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::pair< float, float > state() const
Definition primitives.h:195
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
AddMM(Stream stream, float alpha, float beta)
Definition primitives.h:184
Arange(Stream stream, double start, double stop, double step)
Definition primitives.h:206
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::tuple< double, double, double > state() const
Definition primitives.h:215
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCos(Stream stream)
Definition primitives.h:227
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCosh(Stream stream)
Definition primitives.h:241
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcSin(Stream stream)
Definition primitives.h:255
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcSinh(Stream stream)
Definition primitives.h:269
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTan2(Stream stream)
Definition primitives.h:297
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcTan(Stream stream)
Definition primitives.h:283
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTanh(Stream stream)
Definition primitives.h:311
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::pair< int, int > state() const
Definition primitives.h:336
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArgPartition(Stream stream, int kth, int axis)
Definition primitives.h:325
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
ReduceType
Definition primitives.h:347
@ ArgMin
Definition primitives.h:348
@ ArgMax
Definition primitives.h:349
ArgReduce(Stream stream, ReduceType reduce_type, int axis)
Definition primitives.h:352
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::pair< ReduceType, int > state() const
Definition primitives.h:363
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
ArgSort(Stream stream, int axis)
Definition primitives.h:374
int state() const
Definition primitives.h:384
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:427
AsStrided(Stream stream, Shape shape, Strides strides, size_t offset)
Definition primitives.h:415
void eval_gpu(const std::vector< array > &inputs, array &out) override
AsType(Stream stream, Dtype dtype)
Definition primitives.h:394
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Dtype state() const
Definition primitives.h:405
void eval_cpu(const std::vector< array > &inputs, array &out) override
BitwiseBinary(Stream stream, Op op)
Definition primitives.h:443
void eval_cpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Op
Definition primitives.h:441
@ RightShift
Definition primitives.h:441
@ Or
Definition primitives.h:441
@ LeftShift
Definition primitives.h:441
@ And
Definition primitives.h:441
@ Xor
Definition primitives.h:441
auto state() const
Definition primitives.h:454
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
BitwiseInvert(Stream stream)
Definition primitives.h:464
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
auto state() const
Definition primitives.h:491
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
BlockMaskedMM(Stream stream, int block_size)
Definition primitives.h:477
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
BroadcastAxes(Stream stream, std::vector< int > ignore_axes={})
Definition primitives.h:518
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:532
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const std::vector< array > &inputs, const std::vector< int > &ignore_axes)
Broadcast(Stream stream, const Shape &shape)
Definition primitives.h:543
static Shape output_shape(const std::vector< array > &inputs)
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< int > state() const
Definition primitives.h:555
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Ceil(Stream stream)
Definition primitives.h:567
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2327
Cholesky(Stream stream, bool upper)
Definition primitives.h:2322
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void print(std::ostream &os) override
Print the primitive.
Compiled(Stream stream, std::vector< array > inputs, std::vector< array > outputs, std::vector< array > tape, std::unordered_set< uintptr_t > constant_ids)
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
std::string lib_name() const
Definition primitives.h:608
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:634
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Concatenate(Stream stream, int axis)
Definition primitives.h:623
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Conjugate(Stream stream)
Definition primitives.h:644
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Contiguous(Stream stream, bool allow_col_major)
Definition primitives.h:657
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
Convolution(Stream stream, const std::vector< int > &kernel_strides, const std::vector< int > &padding, const std::vector< int > &kernel_dilation, const std::vector< int > &input_dilation, const int groups=1, const bool flip=false)
Definition primitives.h:676
auto state() const
Definition primitives.h:703
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
Copy(Stream stream)
Definition primitives.h:724
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cos(Stream stream)
Definition primitives.h:741
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Cosh(Stream stream)
Definition primitives.h:755
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes) override
The primitive must know how to vectorize itself across the given axes.
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
The Jacobian-vector product.
CustomTransforms(Stream stream, int num_outputs, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> vjp, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< int > &)> jvp, std::function< std::pair< std::vector< array >, std::vector< int > >(const std::vector< array > &, const std::vector< int > &)> vmap)
Definition primitives.h:769
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Depends(Stream stream)
Definition primitives.h:821
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:867
DivMod(Stream stream)
Definition primitives.h:856
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Divide(Stream stream)
Definition primitives.h:842
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
DynamicSlice(Stream stream, std::vector< int > axes, Shape slice_size)
Definition primitives.h:1963
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1976
auto state() const
Definition primitives.h:1998
DynamicSliceUpdate(Stream stream, std::vector< int > axes)
Definition primitives.h:1987
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
auto state() const
Definition primitives.h:2355
Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
Definition primitives.h:2340
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:913
Equal(Stream stream, bool equal_nan=false)
Definition primitives.h:902
auto state() const
Definition primitives.h:920
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Erf(Stream stream)
Definition primitives.h:930
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ErfInv(Stream stream)
Definition primitives.h:944
void eval_cpu(const std::vector< array > &inputs, array &out) override
Exp(Stream stream)
Definition primitives.h:958
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const array &input, const std::vector< int > &axes)
auto state() const
Definition primitives.h:999
void eval_gpu(const std::vector< array > &inputs, array &out) override
ExpandDims(Stream stream, std::vector< int > axes)
Definition primitives.h:985
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Expm1(Stream stream)
Definition primitives.h:972
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
FFT(Stream stream, const std::vector< size_t > &axes, bool inverse, bool real)
Definition primitives.h:1010
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1025
static Shape output_shape(const array &input, int start_axis, int end_axis)
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
Flatten(Stream stream, int start_axis, int end_axis)
Definition primitives.h:1037
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1050
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Floor(Stream stream)
Definition primitives.h:1062
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Full(Stream stream)
Definition primitives.h:1076
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
GatherAxis(Stream stream, int axis)
Definition primitives.h:1113
auto state() const
Definition primitives.h:1124
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::pair< std::vector< int >, std::vector< int > > state() const
Definition primitives.h:1102
Gather(Stream stream, std::vector< int > axes, Shape slice_sizes)
Definition primitives.h:1089
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
GatherMM(Stream stream)
Definition primitives.h:501
auto state() const
Definition primitives.h:1580
GatherQMM(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1567
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
GreaterEqual(Stream stream)
Definition primitives.h:1148
void eval_gpu(const std::vector< array > &inputs, array &out) override
Greater(Stream stream)
Definition primitives.h:1134
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
Hadamard(Stream stream, float scale)
Definition primitives.h:1162
auto state() const
Definition primitives.h:1174
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Imag(Stream stream)
Definition primitives.h:1184
void eval_gpu(const std::vector< array > &inputs, array &output) override
Inverse(Stream stream, bool tri, bool upper)
Definition primitives.h:2303
auto state() const
Definition primitives.h:2311
void eval_cpu(const std::vector< array > &inputs, array &output) override
LUF(Stream stream)
Definition primitives.h:2367
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
LessEqual(Stream stream)
Definition primitives.h:1212
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Less(Stream stream)
Definition primitives.h:1198
void eval_gpu(const std::vector< array > &inputs, array &out) override
Load(Stream stream, std::shared_ptr< io::Reader > reader, size_t offset, bool swap_endianness=false)
Definition primitives.h:1226
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Log1p(Stream stream)
Definition primitives.h:1286
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogAddExp(Stream stream)
Definition primitives.h:1341
Base
Definition primitives.h:1249
@ ten
Definition primitives.h:1249
@ two
Definition primitives.h:1249
@ e
Definition primitives.h:1249
Log(Stream stream, Base base)
Definition primitives.h:1251
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1266
Base state() const
Definition primitives.h:1262
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalAnd(Stream stream)
Definition primitives.h:1313
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalNot(Stream stream)
Definition primitives.h:1299
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
LogicalOr(Stream stream)
Definition primitives.h:1327
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Matmul(Stream stream)
Definition primitives.h:1355
Maximum(Stream stream)
Definition primitives.h:1369
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Minimum(Stream stream)
Definition primitives.h:1383
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Multiply(Stream stream)
Definition primitives.h:1397
void eval_gpu(const std::vector< array > &inputs, array &out) override
Negative(Stream stream)
Definition primitives.h:1411
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
NotEqual(Stream stream)
Definition primitives.h:1425
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:1455
NumberOfElements(Stream stream, std::vector< int > axes, bool inverted, Dtype dtype)
Definition primitives.h:1439
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::tuple< std::vector< int >, bool, Dtype > state() const
Definition primitives.h:1458
auto state() const
Definition primitives.h:1489
Pad(Stream stream, const std::vector< int > &axes, const Shape &low_pad_size, const Shape &high_pad_size)
Definition primitives.h:1472
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Partition(Stream stream, int kth, int axis)
Definition primitives.h:1501
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
auto state() const
Definition primitives.h:1512
void eval_cpu(const std::vector< array > &inputs, array &out) override
Power(Stream stream)
Definition primitives.h:1523
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:48
virtual void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
The vector-Jacobian product.
virtual ~Primitive()=default
Primitive(const Primitive &other)=delete
Primitive(Primitive &&other)=delete
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
Primitive & operator=(Primitive &&other)=delete
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
Primitive & operator=(const Primitive &other)=delete
const Device & device()
The device the primitive will run on.
Definition primitives.h:53
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
The Jacobian-vector product.
virtual std::vector< Shape > output_shapes(const std::vector< array > &inputs)
Get the output shapes of the primitive.
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes)
The primitive must know how to vectorize itself across the given axes.
virtual void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
virtual void print(std::ostream &os)=0
Print the primitive.
Primitive(Stream stream)
Definition primitives.h:50
QRF(Stream stream)
Definition primitives.h:2269
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
QuantizedMatmul(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1537
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1555
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::pair< std::vector< int >, int > state() const
Definition primitives.h:1601
RandomBits(Stream stream, const Shape &shape, int width)
Definition primitives.h:1592
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Real(Stream stream)
Definition primitives.h:1612
Reduce(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1650
ReduceType
Definition primitives.h:1648
@ Min
Definition primitives.h:1648
@ Or
Definition primitives.h:1648
@ Max
Definition primitives.h:1648
@ And
Definition primitives.h:1648
@ Sum
Definition primitives.h:1648
@ Prod
Definition primitives.h:1648
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1669
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::pair< ReduceType, std::vector< int > > state() const
Definition primitives.h:1692
Remainder(Stream stream)
Definition primitives.h:888
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const array &input, Shape shape)
void eval_gpu(const std::vector< array > &inputs, array &out) override
Reshape(Stream stream, const Shape &shape)
Definition primitives.h:1626
std::vector< int > state() const
Definition primitives.h:1636
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Round(Stream stream)
Definition primitives.h:1703
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
SVD(Stream stream, bool compute_uv)
Definition primitives.h:2282
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
auto state() const
Definition primitives.h:2292
void eval_cpu(const std::vector< array > &inputs, array &out) override
ReduceType
Definition primitives.h:1717
@ Prod
Definition primitives.h:1717
@ Min
Definition primitives.h:1717
@ Max
Definition primitives.h:1717
@ Sum
Definition primitives.h:1717
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
auto state() const
Definition primitives.h:1755
Scan(Stream stream, ReduceType reduce_type, int axis, bool reverse, bool inclusive)
Definition primitives.h:1719
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1737
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::pair< ReduceType, int > state() const
Definition primitives.h:1837
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1824
void eval_gpu(const std::vector< array > &inputs, array &out) override
ScatterAxis(Stream stream, ReduceType reduce_type, int axis)
Definition primitives.h:1815
ReduceType
Definition primitives.h:1813
@ Sum
Definition primitives.h:1813
@ None
Definition primitives.h:1813
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::pair< ReduceType, std::vector< int > > state() const
Definition primitives.h:1802
ReduceType
Definition primitives.h:1768
@ Sum
Definition primitives.h:1768
@ Max
Definition primitives.h:1768
@ Prod
Definition primitives.h:1768
@ None
Definition primitives.h:1768
@ Min
Definition primitives.h:1768
void eval_cpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1782
void eval_gpu(const std::vector< array > &inputs, array &out) override
Scatter(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1770
void eval_gpu(const std::vector< array > &inputs, array &out) override
Select(Stream stream)
Definition primitives.h:874
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sigmoid(Stream stream)
Definition primitives.h:1848
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Sign(Stream stream)
Definition primitives.h:1862
Sin(Stream stream)
Definition primitives.h:1876
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sinh(Stream stream)
Definition primitives.h:1890
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1921
Slice(Stream stream, const Shape &start_indices, const Shape &end_indices, const Shape &strides)
Definition primitives.h:1904
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
SliceUpdate(Stream stream, const Shape &start_indices, const Shape &end_indices, const Shape &strides)
Definition primitives.h:1933
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1951
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Softmax(Stream stream, bool precise)
Definition primitives.h:2008
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2020
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2041
Sort(Stream stream, int axis)
Definition primitives.h:2030
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::pair< std::vector< int >, int > state() const
Definition primitives.h:2063
Split(Stream stream, const Shape &indices, int axis)
Definition primitives.h:2051
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
auto state() const
Definition primitives.h:2100
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sqrt(Stream stream, bool recip=false)
Definition primitives.h:2090
void eval_gpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:2104
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Square(Stream stream)
Definition primitives.h:2076
Squeeze(Stream stream, std::vector< int > axes)
Definition primitives.h:2148
auto state() const
Definition primitives.h:2162
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const array &input, const std::vector< int > &axes)
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
StopGradient(Stream stream)
Definition primitives.h:2118
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Subtract(Stream stream)
Definition primitives.h:2134
Tan(Stream stream)
Definition primitives.h:2173
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Tanh(Stream stream)
Definition primitives.h:2187
void eval_cpu(const std::vector< array > &inputs, array &out) override
Transpose(Stream stream, const std::vector< int > &axes)
Definition primitives.h:2245
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< int > state() const
Definition primitives.h:2256
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:126
UnaryPrimitive & operator=(const UnaryPrimitive &other)=delete
UnaryPrimitive(Stream stream)
An abstract base class for a primitive with a single output.
Definition primitives.h:131
virtual void eval_gpu(const std::vector< array > &inputs, array &output)=0
UnaryPrimitive(UnaryPrimitive &&other)=delete
virtual void eval_cpu(const std::vector< array > &inputs, array &output)=0
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:141
UnaryPrimitive(const UnaryPrimitive &other)=delete
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:136
UnaryPrimitive & operator=(UnaryPrimitive &&other)=delete
virtual ~UnaryPrimitive()=default
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Unflatten(Stream stream, int axis, Shape shape)
Definition primitives.h:2201
static Shape output_shape(const array &input, int axis, const Shape &shape)
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2214
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2235
void print(std::ostream &os) override
Print the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
View(Stream stream, Dtype dtype)
Definition primitives.h:2226
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition array.h:24
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
array tri(int n, int m, int k, Dtype type, StreamOrDevice s={})
array transpose(const array &a, std::vector< int > axes, StreamOrDevice s={})
Permutes the dimensions according to the given axes.
array real(const array &a, StreamOrDevice s={})
Definition allocator.h:7
std::vector< ShapeElem > Shape
Definition array.h:21
std::vector< int64_t > Strides
Definition array.h:22
void eval(std::vector< array > outputs)
#define DEFINE_DEFAULT_IS_EQUIVALENT()
Definition primitives.h:34
#define DEFINE_PRINT(PRIMITIVE)
Definition primitives.h:29
#define DEFINE_INPUT_OUTPUT_SHAPE()
Definition primitives.h:39
#define DEFINE_GRADS()
Definition primitives.h:17
#define DEFINE_VMAP()
Definition primitives.h:12
Definition device.h:7
Definition dtype.h:13
Definition stream.h:9
Device device
Definition stream.h:11