MLX
 
Loading...
Searching...
No Matches
primitives.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <unordered_set>
6
7#include "mlx/array.h"
8#include "mlx/device.h"
9#include "mlx/io/load.h"
10#include "mlx/stream.h"
11
// Helper macros used inside primitive class bodies to stamp out the
// boilerplate overrides of the Primitive interface. Each macro expands to a
// complete member declaration or definition, so call sites need no trailing
// semicolon.

// Declares the `vmap` override; the definition is supplied out of line.
// `override` already implies a virtual function, so the redundant `virtual`
// keyword is not repeated.
#define DEFINE_VMAP()                                                 \
  std::pair<std::vector<array>, std::vector<int>> vmap(               \
      const std::vector<array>& inputs, const std::vector<int>& axes) \
      override;

// Declares the `jvp` and `vjp` overrides; definitions are supplied out of
// line.
#define DEFINE_GRADS()                           \
  std::vector<array> jvp(                        \
      const std::vector<array>& primals,         \
      const std::vector<array>& tangents,        \
      const std::vector<int>& argnums) override; \
                                                 \
  std::vector<array> vjp(                        \
      const std::vector<array>& primals,         \
      const std::vector<array>& cotangents,      \
      const std::vector<int>& argnums,           \
      const std::vector<array>& outputs) override;

// Defines `print` so that it streams the stringified macro argument
// (normally the class name) to `os`.
#define DEFINE_PRINT(PRIMITIVE)           \
  void print(std::ostream& os) override { \
    os << #PRIMITIVE;                     \
  }

// Defines `is_equivalent` for primitives that carry no distinguishing state:
// it always returns true. The parameter is intentionally unnamed because it
// is never inspected.
#define DEFINE_DEFAULT_IS_EQUIVALENT()                            \
  bool is_equivalent(const Primitive& /* other */) const override { \
    return true;                                                  \
  }

// Defines `output_shapes` for primitives whose single output has the shape of
// the first input. `inputs` must be non-empty; this is not checked.
#define DEFINE_INPUT_OUTPUT_SHAPE()                                  \
  std::vector<Shape> output_shapes(const std::vector<array>& inputs) \
      override {                                                     \
    return {inputs[0].shape()};                                      \
  }
44
45namespace mlx::core {
46
47// Abstract base class
48class Primitive {
49 public:
50 explicit Primitive(Stream stream) : stream_(stream) {}
51
53 const Device& device() {
54 return stream().device;
55 }
56
58 const Stream& stream() {
59 return stream_;
60 }
61
69 virtual void eval_cpu(
70 const std::vector<array>& inputs,
71 std::vector<array>& outputs) = 0;
72 virtual void eval_gpu(
73 const std::vector<array>& inputs,
74 std::vector<array>& outputs) = 0;
75
79 virtual std::vector<array> jvp(
80 const std::vector<array>& primals,
81 const std::vector<array>& tangents,
82 const std::vector<int>& argnums);
83
87 virtual std::vector<array> vjp(
88 const std::vector<array>& primals,
89 const std::vector<array>& cotangents,
90 const std::vector<int>& argnums,
91 const std::vector<array>& outputs);
92
99 virtual std::pair<std::vector<array>, std::vector<int>> vmap(
100 const std::vector<array>& inputs,
101 const std::vector<int>& axes);
102
104 virtual void print(std::ostream& os) = 0;
105
107 virtual bool is_equivalent(const Primitive& other) const {
108 return false;
109 }
110
113 virtual std::vector<Shape> output_shapes(const std::vector<array>& inputs);
114
115 virtual ~Primitive() = default;
116 Primitive(const Primitive& other) = delete;
117 Primitive(Primitive&& other) = delete;
118 Primitive& operator=(const Primitive& other) = delete;
119 Primitive& operator=(Primitive&& other) = delete;
120
121 private:
122 // Every primitive stores the stream it should run in
123 Stream stream_;
124};
125
// Base class for primitives that produce exactly one output array.
// Subclasses implement the single-output eval_cpu/eval_gpu overloads; the
// multi-output overrides required by Primitive are provided here as adapters.
// NOTE(review): the listing's embedded line numbers skip 131, 149 and 151, so
// a constructor and possibly further deleted special members are not visible
// in this view -- confirm against the original header.
126class UnaryPrimitive : public Primitive {
130 public:
132
// Single-output evaluation hooks that concrete primitives must implement.
133 virtual void eval_cpu(const std::vector<array>& inputs, array& output) = 0;
134 virtual void eval_gpu(const std::vector<array>& inputs, array& output) = 0;
135
// Adapters for Primitive's vector-of-outputs interface: each forwards to the
// single-output overload using outputs[0]. `outputs` must therefore hold at
// least one array; this is not checked here.
136 inline void eval_cpu(
137 const std::vector<array>& inputs,
138 std::vector<array>& outputs) override {
139 eval_cpu(inputs, outputs[0]);
140 }
141 inline void eval_gpu(
142 const std::vector<array>& inputs,
143 std::vector<array>& outputs) override {
144 eval_gpu(inputs, outputs[0]);
145 }
146
// Copy construction and copy assignment are disabled.
147 virtual ~UnaryPrimitive() = default;
148 UnaryPrimitive(const UnaryPrimitive& other) = delete;
150 UnaryPrimitive& operator=(const UnaryPrimitive& other) = delete;
152};
153
// Single-output primitive named Abs. Only the CPU and GPU evaluation
// overrides are declared in this view; their definitions live elsewhere.
// NOTE(review): the listing's embedded line numbers skip 156 and 161-165, so
// the constructor and further member declarations are likely missing here --
// confirm against the original header.
154class Abs : public UnaryPrimitive {
155 public:
157
158 void eval_cpu(const std::vector<array>& inputs, array& out) override;
159 void eval_gpu(const std::vector<array>& inputs, array& out) override;
160
166};
167
// Single-output primitive named Add. It derives from UnaryPrimitive because
// it writes one output array; its operands arrive through `inputs`.
// NOTE(review): the listing's embedded line numbers skip 170 and 175-179, so
// the constructor and further member declarations are likely missing here --
// confirm against the original header.
168class Add : public UnaryPrimitive {
169 public:
171
172 void eval_cpu(const std::vector<array>& inputs, array& out) override;
173 void eval_gpu(const std::vector<array>& inputs, array& out) override;
174
180};
181
// Single-output primitive parameterized by two float scalars, alpha and
// beta, which are fixed at construction (stored as const members).
// NOTE(review): going by the name, this presumably computes
// alpha * (a @ b) + beta * c -- verify against the implementation.
// NOTE(review): the listing's embedded line numbers skip 190-192; member
// declarations on those lines are not visible here.
182class AddMM : public UnaryPrimitive {
183 public:
184 explicit AddMM(Stream stream, float alpha, float beta)
185 : UnaryPrimitive(stream), alpha_(alpha), beta_(beta) {}
186
187 void eval_cpu(const std::vector<array>& inputs, array& out) override;
188 void eval_gpu(const std::vector<array>& inputs, array& out) override;
189
193
194 bool is_equivalent(const Primitive& other) const override;
// Returns the construction parameters as {alpha, beta}.
195 std::pair<float, float> state() const {
196 return {alpha_, beta_};
197 };
198
199 private:
200 const float alpha_;
201 const float beta_;
202};
203
// Single-output primitive configured with three doubles: start, stop and
// step. Evaluation, equivalence and output-shape logic are defined elsewhere.
// NOTE(review): the listing's embedded line numbers skip 212; a declaration
// on that line is not visible here.
204class Arange : public UnaryPrimitive {
205 public:
206 explicit Arange(Stream stream, double start, double stop, double step)
207 : UnaryPrimitive(stream), start_(start), stop_(stop), step_(step) {}
208
209 void eval_cpu(const std::vector<array>& inputs, array& out) override;
210 void eval_gpu(const std::vector<array>& inputs, array& out) override;
211
213 bool is_equivalent(const Primitive& other) const override;
214 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
// Returns the construction parameters as {start, stop, step}.
215 std::tuple<double, double, double> state() const {
216 return {start_, stop_, step_};
217 };
218
219 private:
220 double start_;
221 double stop_;
222 double step_;
223};
224
// Single-output primitive named ArcCos; declares the CPU and GPU evaluation
// overrides, which are defined elsewhere.
// NOTE(review): the listing's embedded line numbers skip 227 and 232-236, so
// the constructor and further members are likely missing from this view.
225class ArcCos : public UnaryPrimitive {
226 public:
228
229 void eval_cpu(const std::vector<array>& inputs, array& out) override;
230 void eval_gpu(const std::vector<array>& inputs, array& out) override;
231
237};
238
// Single-output primitive named ArcCosh; declares the CPU and GPU evaluation
// overrides, which are defined elsewhere.
// NOTE(review): the listing's embedded line numbers skip 241 and 246-250, so
// the constructor and further members are likely missing from this view.
239class ArcCosh : public UnaryPrimitive {
240 public:
242
243 void eval_cpu(const std::vector<array>& inputs, array& out) override;
244 void eval_gpu(const std::vector<array>& inputs, array& out) override;
245
251};
252
// Single-output primitive named ArcSin; declares the CPU and GPU evaluation
// overrides, which are defined elsewhere.
// NOTE(review): the listing's embedded line numbers skip 255 and 260-264, so
// the constructor and further members are likely missing from this view.
253class ArcSin : public UnaryPrimitive {
254 public:
256
257 void eval_cpu(const std::vector<array>& inputs, array& out) override;
258 void eval_gpu(const std::vector<array>& inputs, array& out) override;
259
265};
266
// Single-output primitive named ArcSinh; declares the CPU and GPU evaluation
// overrides, which are defined elsewhere.
// NOTE(review): the listing's embedded line numbers skip 269 and 274-278, so
// the constructor and further members are likely missing from this view.
267class ArcSinh : public UnaryPrimitive {
268 public:
270
271 void eval_cpu(const std::vector<array>& inputs, array& out) override;
272 void eval_gpu(const std::vector<array>& inputs, array& out) override;
273
279};
280
// Single-output primitive named ArcTan; declares the CPU and GPU evaluation
// overrides, which are defined elsewhere.
// NOTE(review): the listing's embedded line numbers skip 283 and 288-292, so
// the constructor and further members are likely missing from this view.
281class ArcTan : public UnaryPrimitive {
282 public:
284
285 void eval_cpu(const std::vector<array>& inputs, array& out) override;
286 void eval_gpu(const std::vector<array>& inputs, array& out) override;
287
293};
294
// Single-output primitive named ArcTan2; declares the CPU and GPU evaluation
// overrides, which are defined elsewhere. Operands arrive through `inputs`.
// NOTE(review): the listing's embedded line numbers skip 297 and 302-306, so
// the constructor and further members are likely missing from this view.
295class ArcTan2 : public UnaryPrimitive {
296 public:
298
299 void eval_cpu(const std::vector<array>& inputs, array& out) override;
300 void eval_gpu(const std::vector<array>& inputs, array& out) override;
301
307};
308
// Single-output primitive named ArcTanh; declares the CPU and GPU evaluation
// overrides, which are defined elsewhere.
// NOTE(review): the listing's embedded line numbers skip 311 and 316-320, so
// the constructor and further members are likely missing from this view.
309class ArcTanh : public UnaryPrimitive {
310 public:
312
313 void eval_cpu(const std::vector<array>& inputs, array& out) override;
314 void eval_gpu(const std::vector<array>& inputs, array& out) override;
315
321};
322
324 public:
325 explicit ArgPartition(Stream stream, int kth, int axis)
326 : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
327
328 void eval_cpu(const std::vector<array>& inputs, array& out) override;
329 void eval_gpu(const std::vector<array>& inputs, array& out) override;
330
335 bool is_equivalent(const Primitive& other) const override;
336 std::pair<int, int> state() const {
337 return {kth_, axis_};
338 };
339
340 private:
341 int kth_;
342 int axis_;
343};
344
// Single-output primitive configured with a ReduceType and an axis.
// NOTE(review): ReduceType is used below but its declaration is not visible;
// the listing's embedded line numbers skip 347-350 (and 358-360), so it is
// presumably declared on those lines -- confirm against the original header.
345class ArgReduce : public UnaryPrimitive {
346 public:
351
352 explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)
353 : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {}
354
355 void eval_cpu(const std::vector<array>& inputs, array& out) override;
356 void eval_gpu(const std::vector<array>& inputs, array& out) override;
357
361 bool is_equivalent(const Primitive& other) const override;
362 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
// Returns the construction parameters as {reduce_type, axis}.
363 std::pair<ReduceType, int> state() const {
364 return {reduce_type_, axis_};
365 };
366
367 private:
368 ReduceType reduce_type_;
369 int axis_;
370};
371
// Single-output primitive configured with a single integer axis.
// NOTE(review): the listing's embedded line numbers skip 380-382; member
// declarations on those lines are not visible here.
372class ArgSort : public UnaryPrimitive {
373 public:
374 explicit ArgSort(Stream stream, int axis)
375 : UnaryPrimitive(stream), axis_(axis) {}
376
377 void eval_cpu(const std::vector<array>& inputs, array& out) override;
378 void eval_gpu(const std::vector<array>& inputs, array& out) override;
379
383 bool is_equivalent(const Primitive& other) const override;
// Returns the axis supplied at construction.
384 int state() const {
385 return axis_;
386 };
387
388 private:
389 int axis_;
390};
391
// Single-output primitive configured with a target Dtype.
// NOTE(review): the listing's embedded line numbers skip 400-403; member
// declarations on those lines are not visible here.
392class AsType : public UnaryPrimitive {
393 public:
394 explicit AsType(Stream stream, Dtype dtype)
395 : UnaryPrimitive(stream), dtype_(dtype) {}
396
397 void eval_cpu(const std::vector<array>& inputs, array& out) override;
398 void eval_gpu(const std::vector<array>& inputs, array& out) override;
399
404 bool is_equivalent(const Primitive& other) const override;
// Returns the Dtype supplied at construction.
405 Dtype state() const {
406 return dtype_;
407 };
408
409 private:
410 Dtype dtype_;
411};
412
// Single-output primitive configured with a shape, strides and an offset.
// The shape and strides are taken by value and moved into the members.
// NOTE(review): the listing's embedded line numbers skip 416, which is where
// the start of the constructor's initializer list (the base-class
// initializer) would be, and 424-425 -- confirm against the original header.
413class AsStrided : public UnaryPrimitive {
414 public:
415 explicit AsStrided(Stream stream, Shape shape, Strides strides, size_t offset)
417 shape_(std::move(shape)),
418 strides_(std::move(strides)),
419 offset_(offset) {}
420
421 void eval_cpu(const std::vector<array>& inputs, array& out) override;
422 void eval_gpu(const std::vector<array>& inputs, array& out) override;
423
426 bool is_equivalent(const Primitive& other) const override;
// Returns copies of the construction parameters as a
// (shape, strides, offset) tuple.
427 auto state() const {
428 return std::make_tuple(shape_, strides_, offset_);
429 }
430
431 private:
432 Shape shape_;
433 Strides strides_;
434 size_t offset_;
435
// Private helper with the single-output eval signature.
// NOTE(review): presumably shared by eval_cpu and eval_gpu -- confirm.
436 void eval(const std::vector<array>& inputs, array& out);
437};
438
440 public:
442
444 : UnaryPrimitive(stream), op_(op) {}
445
446 void eval_cpu(const std::vector<array>& inputs, array& out) override;
447 void eval_gpu(const std::vector<array>& inputs, array& out) override;
448
451 bool is_equivalent(const Primitive& other) const override;
452 void print(std::ostream& os) override;
454 auto state() const {
455 return op_;
456 }
457
458 private:
459 Op op_;
460};
461
463 public:
465
466 void eval_cpu(const std::vector<array>& inputs, array& out) override;
467 void eval_gpu(const std::vector<array>& inputs, array& out) override;
468
473};
474
476 public:
477 explicit BlockMaskedMM(Stream stream, int block_size)
478 : UnaryPrimitive(stream), block_size_(block_size) {}
479
480 void eval_cpu(const std::vector<array>& inputs, array& out) override;
481 void eval_gpu(const std::vector<array>& inputs, array& out) override;
482
483 std::vector<array> vjp(
484 const std::vector<array>& primals,
485 const std::vector<array>& cotangents,
486 const std::vector<int>& argnums,
487 const std::vector<array>& outputs) override;
488
490 bool is_equivalent(const Primitive& other) const override;
491 auto state() const {
492 return block_size_;
493 }
494
495 private:
496 int block_size_;
497};
498
// Single-output primitive named GatherMM. Besides the evaluation overrides it
// declares a custom vjp override; all definitions live elsewhere.
// NOTE(review): the listing's embedded line numbers skip 501 and 512-513, so
// the constructor and further members are likely missing from this view.
499class GatherMM : public UnaryPrimitive {
500 public:
502
503 void eval_cpu(const std::vector<array>& inputs, array& out) override;
504 void eval_gpu(const std::vector<array>& inputs, array& out) override;
505
506 std::vector<array> vjp(
507 const std::vector<array>& primals,
508 const std::vector<array>& cotangents,
509 const std::vector<int>& argnums,
510 const std::vector<array>& outputs) override;
511
514};
515
517 public:
518 explicit BroadcastAxes(Stream stream, std::vector<int> ignore_axes = {})
519 : UnaryPrimitive(stream), ignore_axes_(std::move(ignore_axes)) {}
520
521 void eval_cpu(const std::vector<array>& inputs, array& out) override;
522 void eval_gpu(const std::vector<array>& inputs, array& out) override;
523
527 bool is_equivalent(const Primitive& other) const override;
529 const std::vector<array>& inputs,
530 const std::vector<int>& ignore_axes);
531 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
532 auto state() const {
533 return ignore_axes_;
534 }
535
536 private:
537 void eval(const std::vector<array>& inputs, array& out);
538 std::vector<int> ignore_axes_;
539};
540
// Single-output primitive configured with a target Shape, copied into the
// shape_ member at construction.
// NOTE(review): the listing's embedded line numbers skip 549-551; member
// declarations on those lines are not visible here.
541class Broadcast : public UnaryPrimitive {
542 public:
543 explicit Broadcast(Stream stream, const Shape& shape)
544 : UnaryPrimitive(stream), shape_(shape) {}
545
546 void eval_cpu(const std::vector<array>& inputs, array& out) override;
547 void eval_gpu(const std::vector<array>& inputs, array& out) override;
548
// Static helper that derives a Shape from a list of input arrays; it needs no
// Broadcast instance. Defined elsewhere.
552 static Shape output_shape(const std::vector<array>& inputs);
553 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
554 bool is_equivalent(const Primitive& other) const override;
// Returns a copy of the stored shape. The declared return type is
// std::vector<int> while the member is a Shape, so this relies on Shape being
// convertible to std::vector<int>.
555 std::vector<int> state() const {
556 return shape_;
557 };
558
559 private:
560 Shape shape_;
561
// Private helper with the single-output eval signature.
// NOTE(review): presumably shared by eval_cpu and eval_gpu -- confirm.
562 void eval(const std::vector<array>& inputs, array& out);
563};
564
565class Ceil : public UnaryPrimitive {
566 public:
568
569 void eval_cpu(const std::vector<array>& inputs, array& out) override;
570 void eval_gpu(const std::vector<array>& inputs, array& out) override;
571
577};
578
// Multi-output primitive (derives from Primitive directly) that stores a
// captured set of inputs, outputs, a tape and a set of constant ids, all as
// const members, plus a kernel library name.
// NOTE(review): the listing's embedded line numbers skip 591 (where the first
// constructor parameter would be) and 602-603 -- confirm against the original
// header.
579class Compiled : public Primitive {
580 public:
581 /*
582 * The inputs, outputs and tape are either tracers or constants.
583 * - The tape should not contain the inputs, but it should contain the
584 * outputs.
585 * - The tape should also have only one array per primitive for multi-output
586 * primitives.
587 * - The constant_ids contains ids of arrays in the input list that are safe
588 * to treat as scalar constants.
589 */
590 explicit Compiled(
592 std::vector<array> inputs,
593 std::vector<array> outputs,
594 std::vector<array> tape,
595 std::unordered_set<uintptr_t> constant_ids);
596
597 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
598 override;
599 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
600 override;
601
604 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
605 void print(std::ostream& os) override;
606 bool is_equivalent(const Primitive& other) const override;
607
// Returns a copy of the kernel library name.
608 std::string lib_name() const {
609 return kernel_lib_;
610 }
611
612 private:
613 const std::vector<array> inputs_;
614 const std::vector<array> outputs_;
615 const std::vector<array> tape_;
616 const std::unordered_set<uintptr_t> constant_ids_;
617
// Not const, unlike the members above; it is not assigned anywhere in this
// header section, so it is presumably set by the out-of-line constructor or
// during evaluation -- confirm.
618 std::string kernel_lib_;
619};
620
622 public:
623 explicit Concatenate(Stream stream, int axis)
624 : UnaryPrimitive(stream), axis_(axis) {}
625
626 void eval_cpu(const std::vector<array>& inputs, array& out) override;
627 void eval_gpu(const std::vector<array>& inputs, array& out) override;
628
632 bool is_equivalent(const Primitive& other) const override;
633 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
634 auto state() const {
635 return axis_;
636 }
637
638 private:
639 int axis_;
640};
641
642class Conjugate : public UnaryPrimitive {
643 public:
645
646 void eval_cpu(const std::vector<array>& inputs, array& out) override;
647 void eval_gpu(const std::vector<array>& inputs, array& out) override;
648
653};
654
656 public:
657 explicit Contiguous(Stream stream, bool allow_col_major)
658 : UnaryPrimitive(stream), allow_col_major_(allow_col_major) {}
659
660 void eval_cpu(const std::vector<array>& inputs, array& out) override;
661 void eval_gpu(const std::vector<array>& inputs, array& out) override;
662
667
668 bool is_equivalent(const Primitive& other) const override;
669
670 private:
671 bool allow_col_major_;
672};
673
675 public:
676 explicit Convolution(
678 const std::vector<int>& kernel_strides,
679 const std::vector<int>& padding,
680 const std::vector<int>& kernel_dilation,
681 const std::vector<int>& input_dilation,
682 const int groups = 1,
683 const bool flip = false)
685 padding_(padding),
686 kernel_strides_(kernel_strides),
687 kernel_dilation_(kernel_dilation),
688 input_dilation_(input_dilation),
689 groups_(groups),
690 flip_(flip) {}
691
692 void eval_cpu(const std::vector<array>& inputs, array& out) override;
693 void eval_gpu(const std::vector<array>& inputs, array& out) override;
694
695 std::vector<array> vjp(
696 const std::vector<array>& primals,
697 const std::vector<array>& cotangents,
698 const std::vector<int>& argnums,
699 const std::vector<array>& outputs) override;
700
702 bool is_equivalent(const Primitive& other) const override;
703 auto state() const {
704 return std::make_tuple(
705 padding_,
706 kernel_strides_,
707 kernel_dilation_,
708 input_dilation_,
709 groups_,
710 flip_);
711 }
712
713 private:
714 std::vector<int> padding_;
715 std::vector<int> kernel_strides_;
716 std::vector<int> kernel_dilation_;
717 std::vector<int> input_dilation_;
718 int groups_;
719 bool flip_;
720};
721
722class Copy : public UnaryPrimitive {
723 public:
725
726 void eval_cpu(const std::vector<array>& inputs, array& out) override;
727 void eval_gpu(const std::vector<array>& inputs, array& out) override;
728
734
735 private:
736 void eval(const std::vector<array>& inputs, array& out);
737};
738
739class Cos : public UnaryPrimitive {
740 public:
742
743 void eval_cpu(const std::vector<array>& inputs, array& out) override;
744 void eval_gpu(const std::vector<array>& inputs, array& out) override;
745
751};
752
753class Cosh : public UnaryPrimitive {
754 public:
756
757 void eval_cpu(const std::vector<array>& inputs, array& out) override;
758 void eval_gpu(const std::vector<array>& inputs, array& out) override;
759
765};
766
768 public:
771 int num_outputs,
772 std::function<std::vector<array>(
773 const std::vector<array>&,
774 const std::vector<array>&,
775 const std::vector<array>&)> vjp,
776 std::function<std::vector<array>(
777 const std::vector<array>&,
778 const std::vector<array>&,
779 const std::vector<int>&)> jvp,
780 std::function<std::pair<std::vector<array>, std::vector<int>>(
781 const std::vector<array>&,
782 const std::vector<int>&)> vmap)
783 : Primitive(stream),
784 num_outputs_(num_outputs),
785 vjp_fun_(std::move(vjp)),
786 jvp_fun_(std::move(jvp)),
787 vmap_fun_(std::move(vmap)) {}
788
789 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
790 override;
791 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
792 override;
793
797
798 private:
799 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
800
801 int num_outputs_;
802
803 std::function<std::vector<array>(
804 const std::vector<array>&,
805 const std::vector<array>&,
806 const std::vector<array>&)>
807 vjp_fun_;
808 std::function<std::vector<array>(
809 const std::vector<array>&,
810 const std::vector<array>&,
811 const std::vector<int>&)>
812 jvp_fun_;
813 std::function<std::pair<std::vector<array>, std::vector<int>>(
814 const std::vector<array>&,
815 const std::vector<int>&)>
816 vmap_fun_;
817};
818
// Multi-output primitive (derives from Primitive directly) named Depends. It
// declares the evaluation overrides, a custom vjp override and a private eval
// helper; all definitions live elsewhere.
// NOTE(review): the listing's embedded line numbers skip 821 and 834, so the
// constructor and at least one further member are likely missing from this
// view.
819class Depends : public Primitive {
820 public:
822
823 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
824 override;
825 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
826 override;
827
828 std::vector<array> vjp(
829 const std::vector<array>& primals,
830 const std::vector<array>& cotan,
831 const std::vector<int>& argnums,
832 const std::vector<array>& outputs) override;
833
835
836 private:
// Private helper with the multi-output eval signature.
// NOTE(review): presumably shared by eval_cpu and eval_gpu -- confirm.
837 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
838};
839
840class Divide : public UnaryPrimitive {
841 public:
843
844 void eval_cpu(const std::vector<array>& inputs, array& out) override;
845 void eval_gpu(const std::vector<array>& inputs, array& out) override;
846
852};
853
// Multi-output primitive (derives from Primitive directly) named DivMod.
// NOTE(review): the listing's embedded line numbers skip 856 and 863-866, so
// the constructor and further members are likely missing from this view.
854class DivMod : public Primitive {
855 public:
857
858 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
859 override;
860 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
861 override;
862
// Reports two outputs, each with the shape of the first input. `inputs` must
// be non-empty; this is not checked.
867 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
868 return std::vector{inputs[0].shape(), inputs[0].shape()};
869 }
870};
871
872class Select : public UnaryPrimitive {
873 public:
875
876 void eval_cpu(const std::vector<array>& inputs, array& out) override;
877 void eval_gpu(const std::vector<array>& inputs, array& out) override;
878
884};
885
886class Remainder : public UnaryPrimitive {
887 public:
889
890 void eval_cpu(const std::vector<array>& inputs, array& out) override;
891 void eval_gpu(const std::vector<array>& inputs, array& out) override;
892
898};
899
// Single-output primitive with one boolean option, equal_nan, which defaults
// to false.
// NOTE(review): presumably equal_nan makes NaN compare equal to NaN during
// evaluation -- the evaluation code is not in this section, so confirm.
// NOTE(review): the listing's embedded line numbers skip 908-911; member
// declarations on those lines are not visible here.
900class Equal : public UnaryPrimitive {
901 public:
902 explicit Equal(Stream stream, bool equal_nan = false)
903 : UnaryPrimitive(stream), equal_nan_(equal_nan) {}
904
905 void eval_cpu(const std::vector<array>& inputs, array& out) override;
906 void eval_gpu(const std::vector<array>& inputs, array& out) override;
907
912
// Prints "NaNEqual" when the equal_nan option is set and "Equal" otherwise.
913 void print(std::ostream& os) override {
914 if (equal_nan_) {
915 os << "NaNEqual";
916 } else {
917 os << "Equal";
918 }
919 }
// Returns the equal_nan option.
920 auto state() const {
921 return equal_nan_;
922 };
923
924 private:
925 bool equal_nan_;
926};
927
928class Erf : public UnaryPrimitive {
929 public:
931
932 void eval_cpu(const std::vector<array>& inputs, array& out) override;
933 void eval_gpu(const std::vector<array>& inputs, array& out) override;
934
940};
941
942class ErfInv : public UnaryPrimitive {
943 public:
945
946 void eval_cpu(const std::vector<array>& inputs, array& out) override;
947 void eval_gpu(const std::vector<array>& inputs, array& out) override;
948
954};
955
956class Exp : public UnaryPrimitive {
957 public:
959
960 void eval_cpu(const std::vector<array>& inputs, array& out) override;
961 void eval_gpu(const std::vector<array>& inputs, array& out) override;
962
968};
969
970class Expm1 : public UnaryPrimitive {
971 public:
973
974 void eval_cpu(const std::vector<array>& inputs, array& out) override;
975 void eval_gpu(const std::vector<array>& inputs, array& out) override;
976
981};
982
984 public:
985 explicit ExpandDims(Stream stream, std::vector<int> axes)
986 : UnaryPrimitive(stream), axes_(std::move(axes)) {}
987
988 void eval_cpu(const std::vector<array>& inputs, array& out) override;
989 void eval_gpu(const std::vector<array>& inputs, array& out) override;
990
994
995 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
996 bool is_equivalent(const Primitive& other) const override;
997
998 static Shape output_shape(const array& input, const std::vector<int>& axes);
999 auto state() const {
1000 return axes_;
1001 }
1002
1003 private:
1004 void eval(const std::vector<array>& inputs, array& out);
1005 std::vector<int> axes_;
1006};
1007
// Single-output primitive configured with a list of axes and two flags,
// `inverse` and `real`. The axes vector is copied at construction.
// NOTE(review): the listing's embedded line numbers skip 1020-1022; member
// declarations on those lines are not visible here.
1008class FFT : public UnaryPrimitive {
1009 public:
1010 explicit FFT(
1011 Stream stream,
1012 const std::vector<size_t>& axes,
1013 bool inverse,
1014 bool real)
1015 : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {}
1016
1017 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1018 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1019
1023
1024 bool is_equivalent(const Primitive& other) const override;
// Returns copies of the construction parameters as an
// (axes, inverse, real) tuple.
1025 auto state() const {
1026 return std::make_tuple(axes_, inverse_, real_);
1027 }
1028
1029 private:
1030 std::vector<size_t> axes_;
1031 bool inverse_;
1032 bool real_;
1033};
1034
// Single-output primitive configured with a start axis and an end axis.
// NOTE(review): whether end_axis is inclusive is not visible in this section
// -- confirm against the implementation.
// NOTE(review): the listing's embedded line numbers skip 1043-1045; member
// declarations on those lines are not visible here.
1035class Flatten : public UnaryPrimitive {
1036 public:
1037 explicit Flatten(Stream stream, int start_axis, int end_axis)
1038 : UnaryPrimitive(stream), start_axis_(start_axis), end_axis_(end_axis) {}
1039
1040 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1041 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1042
1046 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1047 bool is_equivalent(const Primitive& other) const override;
1048
// Static helper that derives the result Shape from an input array and the two
// axes; it needs no Flatten instance. Defined elsewhere.
1049 static Shape output_shape(const array& input, int start_axis, int end_axis);
// Returns the construction parameters as a (start_axis, end_axis) pair.
1050 auto state() const {
1051 return std::make_pair(start_axis_, end_axis_);
1052 }
1053
1054 private:
1055 int start_axis_;
1056 int end_axis_;
// Private helper with the single-output eval signature.
// NOTE(review): presumably shared by eval_cpu and eval_gpu -- confirm.
1057 void eval(const std::vector<array>& inputs, array& out);
1058};
1059
1060class Floor : public UnaryPrimitive {
1061 public:
1063
1064 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1065 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1066
1072};
1073
1074class Full : public UnaryPrimitive {
1075 public:
1077
1078 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1079 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1080
1085};
1086
// Single-output primitive configured with a list of axes and a Shape of slice
// sizes, both taken by value and moved into the members.
// NOTE(review): the listing's embedded line numbers skip 1090, which is where
// the start of the constructor's initializer list (the base-class
// initializer) would be, and 1097-1099 -- confirm against the original
// header.
1087class Gather : public UnaryPrimitive {
1088 public:
1089 explicit Gather(Stream stream, std::vector<int> axes, Shape slice_sizes)
1091 axes_(std::move(axes)),
1092 slice_sizes_(std::move(slice_sizes)) {}
1093
1094 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1095 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1096
1100 bool is_equivalent(const Primitive& other) const override;
1101 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
// Returns copies of {axes, slice_sizes}. The second element is declared as
// std::vector<int> while the member is a Shape, so this relies on Shape being
// convertible to std::vector<int>.
1102 std::pair<std::vector<int>, std::vector<int>> state() const {
1103 return {axes_, slice_sizes_};
1104 }
1105
1106 private:
1107 std::vector<int> axes_;
1108 Shape slice_sizes_;
1109};
1110
1112 public:
1113 explicit GatherAxis(Stream stream, int axis)
1114 : UnaryPrimitive(stream), axis_(axis) {}
1115
1116 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1117 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1118
1122 bool is_equivalent(const Primitive& other) const override;
1123 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1124 auto state() const {
1125 return axis_;
1126 }
1127
1128 private:
1129 int axis_;
1130};
1131
1132class Greater : public UnaryPrimitive {
1133 public:
1135
1136 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1137 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1138
1144};
1145
1147 public:
1149
1150 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1151 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1152
1158};
1159
// Single-output primitive configured with a single float scale factor.
// NOTE(review): how the scale is applied is not visible in this section --
// confirm against the implementation.
// NOTE(review): the listing's embedded line numbers skip 1168-1171; member
// declarations on those lines are not visible here.
1160class Hadamard : public UnaryPrimitive {
1161 public:
1162 explicit Hadamard(Stream stream, float scale)
1163 : UnaryPrimitive(stream), scale_(scale) {}
1164
1165 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1166 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1167
1172
1173 bool is_equivalent(const Primitive& other) const override;
// Returns the scale supplied at construction.
1174 auto state() const {
1175 return scale_;
1176 }
1177
1178 private:
1179 float scale_;
1180};
1181
1182class Imag : public UnaryPrimitive {
1183 public:
1185
1186 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1187 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1188
1194};
1195
1196class Less : public UnaryPrimitive {
1197 public:
1199
1200 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1201 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1202
1208};
1209
1211 public:
1213
1214 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1215 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1216
1222};
1223
// Single-output primitive that holds a shared io::Reader, an offset and an
// endianness-swap flag (default false).
// NOTE(review): presumably evaluation reads the array data from the reader
// starting at `offset` -- the evaluation code is not in this section, so
// confirm.
// NOTE(review): the listing's embedded line numbers skip 1231, which is where
// the start of the constructor's initializer list (the base-class
// initializer) would be, and 1243 -- confirm against the original header.
1224class Load : public UnaryPrimitive {
1225 public:
1226 explicit Load(
1227 Stream stream,
1228 std::shared_ptr<io::Reader> reader,
1229 size_t offset,
1230 bool swap_endianness = false)
1232 reader_(std::move(reader)),
1233 offset_(offset),
1234 swap_endianness_(swap_endianness) {
// When constructed for a GPU stream, call io_stream() and discard the result:
// the only visible effect is to force creation of its function-local static
// CPU stream at construction time.
1235 if (stream.device == Device::gpu) {
1236 io_stream();
1237 }
1238 }
1239
1240 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1241 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1242
1244
1245 private:
// Returns a reference to a function-local static Stream, created on the first
// call via new_stream(Device::cpu). Being a function-local static, the same
// stream is returned to every Load instance.
1246 Stream& io_stream() {
1247 static Stream io_stream = new_stream(Device::cpu);
1248 return io_stream;
1249 };
1250 std::shared_ptr<io::Reader> reader_;
1251 size_t offset_;
1252 bool swap_endianness_;
1253};
1254
// Single-output primitive covering three logarithm variants, selected by the
// Base value supplied at construction.
// NOTE(review): the listing's embedded line numbers skip 1265-1268; member
// declarations on those lines are not visible here.
1255class Log : public UnaryPrimitive {
1256 public:
// Selects the variant; print() below maps e, two and ten to "Log", "Log2"
// and "Log10".
1257 enum Base { two, ten, e };
1258
1259 explicit Log(Stream stream, Base base)
1260 : UnaryPrimitive(stream), base_(base) {}
1261
1262 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1263 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1264
1269
// Returns the Base supplied at construction.
1270 Base state() const {
1271 return base_;
1272 };
1273
// Prints the variant name for the stored base. All three enumerators are
// handled; any other value prints nothing.
1274 void print(std::ostream& os) override {
1275 switch (base_) {
1276 case e:
1277 os << "Log";
1278 break;
1279 case two:
1280 os << "Log2";
1281 break;
1282 case ten:
1283 os << "Log10";
1284 break;
1285 }
1286 }
1287
1288 private:
1289 Base base_;
1290};
1291
1292class Log1p : public UnaryPrimitive {
1293 public:
1295
1296 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1297 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1298
1303};
1304
1306 public:
1308
1309 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1310 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1311
1317};
1318
1320 public:
1322
1323 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1324 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1325
1331};
1332
1334 public:
1336
1337 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1338 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1339
1345};
1346
1348 public:
1350
1351 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1352 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1353
1359};
1360
// Single-output primitive named Matmul. It declares the evaluation overrides
// and a custom output_shapes override; all definitions live elsewhere.
// NOTE(review): the listing's embedded line numbers skip 1363 and 1368-1371,
// so the constructor and further members are likely missing from this view.
1361class Matmul : public UnaryPrimitive {
1362 public:
1364
1365 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1366 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1367
1372 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1373};
1374
1375class Maximum : public UnaryPrimitive {
1376 public:
1378
1379 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1380 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1381
1387};
1388
1389class Minimum : public UnaryPrimitive {
1390 public:
1392
1393 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1394 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1395
1401};
1402
1403class Multiply : public UnaryPrimitive {
1404 public:
1406
1407 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1408 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1409
1415};
1416
1417class Negative : public UnaryPrimitive {
1418 public:
1420
1421 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1422 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1423
1429};
1430
1431class NotEqual : public UnaryPrimitive {
1432 public:
1434
1435 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1436 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1437
1443};
1444
1446 public:
1448 Stream stream,
1449 std::vector<int> axes,
1450 bool inverted,
1451 Dtype dtype)
1453 axes_(std::move(axes)),
1454 inverted_(inverted),
1455 dtype_(dtype) {}
1456
1457 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1458 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1459
1462 bool is_equivalent(const Primitive& other) const override;
1463 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
1464 return {{}};
1465 }
1466 std::tuple<std::vector<int>, bool, Dtype> state() const {
1467 return {axes_, inverted_, dtype_};
1468 }
1469
1470 private:
1471 std::vector<int> axes_;
1472 bool inverted_;
1473 Dtype dtype_;
1474
1475 void eval(const std::vector<array>& inputs, array& out);
1476};
1477
// Single-output primitive configured with a list of axes and two Shapes
// holding the low and high pad sizes; all three are copied at construction.
// NOTE(review): presumably low/high pad sizes are given per entry of `axes`
// -- confirm against the implementation.
// NOTE(review): the listing's embedded line numbers skip 1485, which is where
// the start of the constructor's initializer list (the base-class
// initializer) would be, and 1493-1495 -- confirm against the original
// header.
1478class Pad : public UnaryPrimitive {
1479 public:
1480 explicit Pad(
1481 Stream stream,
1482 const std::vector<int>& axes,
1483 const Shape& low_pad_size,
1484 const Shape& high_pad_size)
1486 axes_(axes),
1487 low_pad_size_(low_pad_size),
1488 high_pad_size_(high_pad_size) {}
1489
1490 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1491 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1492
1496 bool is_equivalent(const Primitive& other) const override;
// Returns copies of the construction parameters as an
// (axes, low_pad_size, high_pad_size) tuple.
1497 auto state() const {
1498 return std::make_tuple(axes_, low_pad_size_, high_pad_size_);
1499 }
1500
1501 private:
1502 std::vector<int> axes_;
1503 Shape low_pad_size_;
1504 Shape high_pad_size_;
1505};
1506
// Partition: per the visible constructor, a UnaryPrimitive that stores an
// integer `kth` and an integer `axis`.
// NOTE(review): the class header (source line 1507) and lines 1515-1518 are
// missing from this listing — confirm against the real header.
1508 public:
1509 explicit Partition(Stream stream, int kth, int axis)
1510 : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
1511
1512 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1513 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1514
1519 bool is_equivalent(const Primitive& other) const override;
  // Returns the stored parameters as the pair (kth, axis).
1520 auto state() const {
1521 return std::make_pair(kth_, axis_);
1522 };
1523
1524 private:
1525 int kth_;
1526 int axis_;
1527};
1528
// Power: a UnaryPrimitive subclass. The lines shown declare only the CPU and
// GPU evaluation overrides, each taking the input arrays and one output.
// NOTE(review): this listing skips source lines 1531 and 1536-1540, so the
// constructor and any other declarations are not visible here — confirm
// against the real header.
1529class Power : public UnaryPrimitive {
1530 public:
1532
1533 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1534 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1535
1541};
1542
// NOTE(review): the class header (source line 1543), the constructor name
// (1545), the first member-initializer line (1550) and lines 1558-1560 are
// missing from this listing. Presumably this is the QuantizedMatmul
// primitive — confirm against the real header.
//
// What is visible: the constructor takes a stream plus group_size, bits and a
// transpose flag, and stores the last three in same-named members.
1544 public:
1546 Stream stream,
1547 int group_size,
1548 int bits,
1549 bool transpose)
1551 group_size_(group_size),
1552 bits_(bits),
1553 transpose_(transpose) {}
1554
1555 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1556 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1557
1561 bool is_equivalent(const Primitive& other) const override;
1562 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
  // Returns the stored parameters as the tuple (group_size, bits, transpose).
1563 auto state() const {
1564 return std::make_tuple(group_size_, bits_, transpose_);
1565 }
1566
1567 private:
1568 int group_size_;
1569 int bits_;
1570 bool transpose_;
1571};
1572
// GatherQMM: per the visible constructor, stores group_size, bits and a
// transpose flag in same-named members.
// NOTE(review): the class header (source line 1573), the first
// member-initializer line (1576) and lines 1584-1586 are missing from this
// listing, so the base class is not visible here — confirm against the real
// header.
1574 public:
1575 explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
1577 group_size_(group_size),
1578 bits_(bits),
1579 transpose_(transpose) {}
1580
1581 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1582 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1583
1587 bool is_equivalent(const Primitive& other) const override;
  // Returns the stored parameters as the tuple (group_size, bits, transpose).
1588 auto state() const {
1589 return std::make_tuple(group_size_, bits_, transpose_);
1590 }
1591
1592 private:
1593 int group_size_;
1594 int bits_;
1595 bool transpose_;
1596};
1597
// RandomBits: per the visible constructor, a UnaryPrimitive that stores a
// shape and an integer `width`.
// NOTE(review): the class header (source line 1598) and lines 1606-1607 are
// missing from this listing — confirm against the real header.
1599 public:
1600 explicit RandomBits(Stream stream, const Shape& shape, int width)
1601 : UnaryPrimitive(stream), shape_(shape), width_(width) {}
1602
1603 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1604 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1605
1608 bool is_equivalent(const Primitive& other) const override;
  // Returns (shape, width). The return type spells the shape as
  // std::vector<int> while the member is declared as Shape; this relies on
  // Shape being convertible to std::vector<int> — presumably an alias, verify.
1609 std::pair<std::vector<int>, int> state() const {
1610 return {shape_, width_};
1611 };
1612
1613 private:
1614 Shape shape_;
1615 int width_;
1616};
1617
// Real: a UnaryPrimitive subclass. The lines shown declare only the CPU and
// GPU evaluation overrides, each taking the input arrays and one output.
// NOTE(review): this listing skips source lines 1620 and 1625-1629, so the
// constructor and any other declarations are not visible here — confirm
// against the real header.
1618class Real : public UnaryPrimitive {
1619 public:
1621
1622 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1623 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1624
1630};
1631
// Reshape: a UnaryPrimitive subclass that stores a target shape given at
// construction.
// NOTE(review): this listing skips source lines 1640-1642 (presumably
// macro-generated declarations) — confirm against the real header.
1632class Reshape : public UnaryPrimitive {
1633 public:
1634 explicit Reshape(Stream stream, const Shape& shape)
1635 : UnaryPrimitive(stream), shape_(shape) {}
1636
1637 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1638 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1639
1643 bool is_equivalent(const Primitive& other) const override;
  // Returns a copy of the stored shape.
1644 std::vector<int> state() const {
1645 return shape_;
1646 };
  // Static helper taking an input array and a requested shape and returning
  // a Shape; its definition is not in this file.
1647 static Shape output_shape(const array& input, Shape shape);
1648 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1649
1650 private:
1651 Shape shape_;
1652};
1653
// Reduce: a UnaryPrimitive subclass parameterized by a reduction kind
// (And, Or, Sum, Prod, Min or Max) and the list of axes it applies to.
// NOTE(review): this listing skips source line 1667 — confirm against the
// real header what is declared there.
1654class Reduce : public UnaryPrimitive {
1655 public:
1656 enum ReduceType { And, Or, Sum, Prod, Min, Max };
1657
1658 explicit Reduce(
1659 Stream stream,
1660 ReduceType reduce_type,
1661 const std::vector<int>& axes)
1662 : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1663
1664 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1665 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1666
1668
1669 std::vector<array> vjp(
1670 const std::vector<array>& primals,
1671 const std::vector<array>& cotangents,
1672 const std::vector<int>& argnums,
1673 const std::vector<array>& outputs) override;
1674
1675 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1676
  // Writes the enumerator name of the stored reduction kind to `os`.
  // Every enumerator is handled; there is no default case.
1677 void print(std::ostream& os) override {
1678 switch (reduce_type_) {
1679 case And:
1680 os << "And";
1681 break;
1682 case Or:
1683 os << "Or";
1684 break;
1685 case Sum:
1686 os << "Sum";
1687 break;
1688 case Prod:
1689 os << "Prod";
1690 break;
1691 case Min:
1692 os << "Min";
1693 break;
1694 case Max:
1695 os << "Max";
1696 break;
1697 }
1698 }
1699 bool is_equivalent(const Primitive& other) const override;
  // Returns (reduce_type, copy of axes).
1700 std::pair<ReduceType, std::vector<int>> state() const {
1701 return {reduce_type_, axes_};
1702 };
1703
1704 private:
1705 ReduceType reduce_type_;
1706 std::vector<int> axes_;
1707};
1708
// Round: a UnaryPrimitive subclass. The lines shown declare only the CPU and
// GPU evaluation overrides, each taking the input arrays and one output.
// NOTE(review): this listing skips source lines 1711 and 1716-1720, so the
// constructor and any other declarations are not visible here — confirm
// against the real header.
1709class Round : public UnaryPrimitive {
1710 public:
1712
1713 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1714 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1715
1721};
1722
// Scan: a UnaryPrimitive subclass parameterized by a reduction kind, an
// axis, and `reverse` / `inclusive` flags, all stored at construction.
// NOTE(review): this listing skips source lines 1725 (presumably the nested
// ReduceType enum — print() below uses Sum, Prod, Min and Max), 1733
// (presumably the base-class initializer) and 1742-1743 — confirm against the
// real header.
1723class Scan : public UnaryPrimitive {
1724 public:
1726
1727 explicit Scan(
1728 Stream stream,
1729 ReduceType reduce_type,
1730 int axis,
1731 bool reverse,
1732 bool inclusive)
1734 reduce_type_(reduce_type),
1735 axis_(axis),
1736 reverse_(reverse),
1737 inclusive_(inclusive) {}
1738
1739 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1740 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1741
1744
  // Writes "Cum" followed by the reduction name, e.g. "CumSum" or "CumMax".
1745 void print(std::ostream& os) override {
1746 os << "Cum";
1747 switch (reduce_type_) {
1748 case Sum:
1749 os << "Sum";
1750 break;
1751 case Prod:
1752 os << "Prod";
1753 break;
1754 case Min:
1755 os << "Min";
1756 break;
1757 case Max:
1758 os << "Max";
1759 break;
1760 }
1761 }
1762 bool is_equivalent(const Primitive& other) const override;
  // Returns the tuple (reduce_type, axis, reverse, inclusive).
1763 auto state() const {
1764 return std::make_tuple(reduce_type_, axis_, reverse_, inclusive_);
1765 }
1766
1767 private:
1768 ReduceType reduce_type_;
1769 int axis_;
1770 bool reverse_;
1771 bool inclusive_;
1772};
1773
// Scatter: a UnaryPrimitive subclass parameterized by a reduction kind and a
// list of axes, both stored at construction.
// NOTE(review): this listing skips source lines 1776 (presumably the nested
// ReduceType enum — print() below uses Sum, Prod, Min, Max and None) and
// 1787-1788 — confirm against the real header.
1774class Scatter : public UnaryPrimitive {
1775 public:
1777
1778 explicit Scatter(
1779 Stream stream,
1780 ReduceType reduce_type,
1781 const std::vector<int>& axes)
1782 : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1783
1784 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1785 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1786
1789
  // Writes "Scatter", then a space and the reduction name for Sum, Prod, Min
  // or Max. For None nothing is appended, so the output is just "Scatter".
1790 void print(std::ostream& os) override {
1791 os << "Scatter";
1792 switch (reduce_type_) {
1793 case Sum:
1794 os << " Sum";
1795 break;
1796 case Prod:
1797 os << " Prod";
1798 break;
1799 case Min:
1800 os << " Min";
1801 break;
1802 case Max:
1803 os << " Max";
1804 break;
1805 case None:
1806 break;
1807 }
1808 }
1809 bool is_equivalent(const Primitive& other) const override;
  // Returns (reduce_type, copy of axes).
1810 std::pair<ReduceType, std::vector<int>> state() const {
1811 return {reduce_type_, axes_};
1812 };
1813
1814 private:
1815 ReduceType reduce_type_;
1816 std::vector<int> axes_;
1817};
1818
// ScatterAxis: per the visible constructor, a UnaryPrimitive that stores a
// reduction kind and a single integer axis.
// NOTE(review): the class header (source line 1819), line 1821 (presumably
// the nested ReduceType enum — print() below uses Sum and None) and lines
// 1829-1830 are missing from this listing — confirm against the real header.
1820 public:
1822
1823 explicit ScatterAxis(Stream stream, ReduceType reduce_type, int axis)
1824 : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {}
1825
1826 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1827 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1828
1831
  // Writes "ScatterAxis", followed by " Sum" when the reduction kind is Sum.
  // For None nothing is appended.
1832 void print(std::ostream& os) override {
1833 os << "ScatterAxis";
1834 switch (reduce_type_) {
1835 case Sum:
1836 os << " Sum";
1837 break;
1838 case None:
1839 break;
1840 }
1841 }
1842
1843 bool is_equivalent(const Primitive& other) const override;
1844 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
  // Returns the pair (reduce_type, axis).
1845 std::pair<ReduceType, int> state() const {
1846 return {reduce_type_, axis_};
1847 }
1848
1849 private:
1850 ReduceType reduce_type_;
1851 int axis_;
1852};
1853
// Sigmoid: a UnaryPrimitive subclass. The lines shown declare only the CPU
// and GPU evaluation overrides, each taking the input arrays and one output.
// NOTE(review): this listing skips source lines 1856 and 1861-1865, so the
// constructor and any other declarations are not visible here — confirm
// against the real header.
1854class Sigmoid : public UnaryPrimitive {
1855 public:
1857
1858 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1859 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1860
1866};
1867
// Sign: a UnaryPrimitive subclass. The lines shown declare only the CPU and
// GPU evaluation overrides, each taking the input arrays and one output.
// NOTE(review): this listing skips source lines 1870 and 1875-1879, so the
// constructor and any other declarations are not visible here — confirm
// against the real header.
1868class Sign : public UnaryPrimitive {
1869 public:
1871
1872 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1873 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1874
1880};
1881
// Sin: a UnaryPrimitive subclass. The lines shown declare only the CPU and
// GPU evaluation overrides, each taking the input arrays and one output.
// NOTE(review): this listing skips source lines 1884 and 1889-1893, so the
// constructor and any other declarations are not visible here — confirm
// against the real header.
1882class Sin : public UnaryPrimitive {
1883 public:
1885
1886 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1887 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1888
1894};
1895
// Sinh: a UnaryPrimitive subclass. The lines shown declare only the CPU and
// GPU evaluation overrides, each taking the input arrays and one output.
// NOTE(review): this listing skips source lines 1898 and 1903-1907, so the
// constructor and any other declarations are not visible here — confirm
// against the real header.
1896class Sinh : public UnaryPrimitive {
1897 public:
1899
1900 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1901 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1902
1908};
1909
// Slice: a UnaryPrimitive subclass parameterized by start indices, end
// indices and strides, each passed as a Shape and copied into a member.
// NOTE(review): this listing skips source lines 1917 (presumably the
// base-class initializer) and 1925-1927 — confirm against the real header.
1910class Slice : public UnaryPrimitive {
1911 public:
1912 explicit Slice(
1913 Stream stream,
1914 const Shape& start_indices,
1915 const Shape& end_indices,
1916 const Shape& strides)
1918 start_indices_(start_indices),
1919 end_indices_(end_indices),
1920 strides_(strides) {}
1921
1922 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1923 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1924
1928 bool is_equivalent(const Primitive& other) const override;
  // Returns a tuple of copies: (start_indices, end_indices, strides).
1929 auto state() const {
1930 return std::make_tuple(start_indices_, end_indices_, strides_);
1931 }
1932
1933 private:
1934 Shape start_indices_;
1935 Shape end_indices_;
1936 Shape strides_;
1937};
1938
// SliceUpdate: per the visible constructor, stores start indices, end
// indices and strides (each a Shape) in same-named members.
// NOTE(review): the class header (source line 1939), the first
// member-initializer line (1946) and lines 1954-1956 and 1958 are missing
// from this listing, so the base class is not visible here — confirm against
// the real header.
1940 public:
1941 explicit SliceUpdate(
1942 Stream stream,
1943 const Shape& start_indices,
1944 const Shape& end_indices,
1945 const Shape& strides)
1947 start_indices_(start_indices),
1948 end_indices_(end_indices),
1949 strides_(strides) {}
1950
1951 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1952 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1953
1957 bool is_equivalent(const Primitive& other) const override;
  // Returns a tuple of copies: (start_indices, end_indices, strides).
1959 auto state() const {
1960 return std::make_tuple(start_indices_, end_indices_, strides_);
1961 }
1962
1963 private:
1964 Shape start_indices_;
1965 Shape end_indices_;
1966 Shape strides_;
1967};
1968
// DynamicSlice: per the visible constructor, takes a list of axes and a
// slice size by value and moves both into members.
// NOTE(review): the class header (source line 1969), the first
// member-initializer line (1972) and lines 1979-1981 are missing from this
// listing, so the base class is not visible here — confirm against the real
// header.
1970 public:
1971 explicit DynamicSlice(Stream stream, std::vector<int> axes, Shape slice_size)
1973 axes_(std::move(axes)),
1974 slice_size_(std::move(slice_size)) {}
1975
1976 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1977 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1978
1982 bool is_equivalent(const Primitive& other) const override;
1983 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
  // Returns a pair of copies: (axes, slice_size).
1984 auto state() const {
1985 return std::make_pair(axes_, slice_size_);
1986 }
1987
1988 private:
1989 std::vector<int> axes_;
1990 Shape slice_size_;
1991};
1992
// DynamicSliceUpdate: per the visible constructor, a UnaryPrimitive that
// takes a list of axes by value and moves it into axes_.
// NOTE(review): the class header (source line 1993) and lines 2001-2003 and
// 2005 are missing from this listing — confirm against the real header.
1994 public:
1995 explicit DynamicSliceUpdate(Stream stream, std::vector<int> axes)
1996 : UnaryPrimitive(stream), axes_(std::move(axes)) {}
1997
1998 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1999 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2000
2004 bool is_equivalent(const Primitive& other) const override;
  // Returns a copy of the stored axes (plain `auto` deduces a value type).
2006 auto state() const {
2007 return axes_;
2008 }
2009
2010 private:
2011 std::vector<int> axes_;
2012};
2013
// Softmax: a UnaryPrimitive subclass that stores a single `precise` flag
// given at construction.
// NOTE(review): what `precise` changes is decided in the eval
// implementations, which are not in this file. This listing also skips
// source lines 2022-2025 — confirm against the real header.
2014class Softmax : public UnaryPrimitive {
2015 public:
2016 explicit Softmax(Stream stream, bool precise)
2017 : UnaryPrimitive(stream), precise_(precise) {}
2018
2019 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2020 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2021
2026
2027 bool is_equivalent(const Primitive& other) const override;
  // Returns the stored `precise` flag.
2028 auto state() const {
2029 return precise_;
2030 };
2031
2032 private:
2033 bool precise_;
2034};
2035
// Sort: a UnaryPrimitive subclass that stores the integer axis given at
// construction.
// NOTE(review): this listing skips source lines 2044-2047 (presumably
// macro-generated declarations) — confirm against the real header.
2036class Sort : public UnaryPrimitive {
2037 public:
2038 explicit Sort(Stream stream, int axis)
2039 : UnaryPrimitive(stream), axis_(axis) {}
2040
2041 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2042 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2043
2048 bool is_equivalent(const Primitive& other) const override;
  // Returns the stored axis.
2049 auto state() const {
2050 return axis_;
2051 }
2052
2053 private:
2054 int axis_;
2055};
2056
// Split: derives directly from Primitive (not UnaryPrimitive), so its eval
// overrides take a vector of output arrays rather than a single one. Stores
// a list of indices (as a Shape) and an integer axis.
// NOTE(review): this listing skips source lines 2067-2069 — confirm against
// the real header.
2057class Split : public Primitive {
2058 public:
2059 explicit Split(Stream stream, const Shape& indices, int axis)
2060 : Primitive(stream), indices_(indices), axis_(axis) {}
2061
2062 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2063 override;
2064 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2065 override;
2066
2070 bool is_equivalent(const Primitive& other) const override;
  // Returns (copy of indices, axis).
2071 std::pair<std::vector<int>, int> state() const {
2072 return {indices_, axis_};
2073 };
2074
2075 private:
  // Private evaluation helper; presumably shared by eval_cpu and eval_gpu —
  // the definition is not in this file, verify.
2076 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
2077
2078 Shape indices_;
2079 int axis_;
2080};
2081
// Square: a UnaryPrimitive subclass. The lines shown declare only the CPU and
// GPU evaluation overrides, each taking the input arrays and one output.
// NOTE(review): this listing skips source lines 2084 and 2089-2093, so the
// constructor and any other declarations are not visible here — confirm
// against the real header.
2082class Square : public UnaryPrimitive {
2083 public:
2085
2086 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2087 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2088
2094};
2095
// Sqrt: a UnaryPrimitive subclass with one flag, `recip` (default false).
// The flag selects the printed name ("Rsqrt" vs "Sqrt"); presumably it also
// selects reciprocal square root in the eval implementations, which are not
// in this file — verify.
// NOTE(review): this listing skips source lines 2104-2106 — confirm against
// the real header.
2096class Sqrt : public UnaryPrimitive {
2097 public:
2098 explicit Sqrt(Stream stream, bool recip = false)
2099 : UnaryPrimitive(stream), recip_(recip) {}
2100
2101 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2102 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2103
2107 bool is_equivalent(const Primitive& other) const override;
  // Returns the stored `recip` flag.
2108 auto state() const {
2109 return recip_;
2110 }
2111
  // Writes "Rsqrt" when recip_ is set, otherwise "Sqrt".
2112 void print(std::ostream& os) override {
2113 if (recip_) {
2114 os << "Rsqrt";
2115 } else {
2116 os << "Sqrt";
2117 }
2118 }
2119
2120 private:
2121 bool recip_;
2122};
2123
// NOTE(review): the class header (source line 2124), line 2126 (presumably
// the constructor) and lines 2131-2134 are missing from this listing.
// Presumably this is the StopGradient primitive — confirm against the real
// header.
//
// What is visible: CPU/GPU evaluation overrides taking a single output
// array, plus a private eval helper with the same parameters.
2125 public:
2127
2128 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2129 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2130
2135
2136 private:
2137 void eval(const std::vector<array>& inputs, array& out);
2138};
2139
// Subtract: a UnaryPrimitive subclass. The lines shown declare only the CPU
// and GPU evaluation overrides, each taking the input arrays and one output.
// NOTE(review): this listing skips source lines 2142 and 2147-2151, so the
// constructor and any other declarations are not visible here — confirm
// against the real header.
2140class Subtract : public UnaryPrimitive {
2141 public:
2143
2144 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2145 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2146
2152};
2153
// Squeeze: a UnaryPrimitive subclass that takes a list of axes by value and
// moves it into axes_.
// NOTE(review): this listing skips source lines 2162-2164 (presumably
// macro-generated declarations) — confirm against the real header.
2154class Squeeze : public UnaryPrimitive {
2155 public:
2156 explicit Squeeze(Stream stream, std::vector<int> axes)
2157 : UnaryPrimitive(stream), axes_(std::move(axes)) {}
2158
2159 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2160 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2161
2165
2166 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
2167 bool is_equivalent(const Primitive& other) const override;
2168
  // Static helper taking an input array and a list of axes and returning a
  // Shape; its definition is not in this file.
2169 static Shape output_shape(const array& input, const std::vector<int>& axes);
  // Returns a copy of the stored axes (plain `auto` deduces a value type).
2170 auto state() const {
2171 return axes_;
2172 };
2173
2174 private:
  // Private evaluation helper; presumably shared by eval_cpu and eval_gpu —
  // the definition is not in this file, verify.
2175 void eval(const std::vector<array>& inputs, array& out);
2176 std::vector<int> axes_;
2177};
2178
// Tan: a UnaryPrimitive subclass. The lines shown declare only the CPU and
// GPU evaluation overrides, each taking the input arrays and one output.
// NOTE(review): this listing skips source lines 2181 and 2186-2190, so the
// constructor and any other declarations are not visible here — confirm
// against the real header.
2179class Tan : public UnaryPrimitive {
2180 public:
2182
2183 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2184 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2185
2191};
2192
// Tanh: a UnaryPrimitive subclass. The lines shown declare only the CPU and
// GPU evaluation overrides, each taking the input arrays and one output.
// NOTE(review): this listing skips source lines 2195 and 2200-2204, so the
// constructor and any other declarations are not visible here — confirm
// against the real header.
2193class Tanh : public UnaryPrimitive {
2194 public:
2196
2197 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2198 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2199
2205};
2206
// Unflatten: per the visible constructor, a UnaryPrimitive that stores an
// integer axis and a shape (taken by value and moved into shape_).
// NOTE(review): the class header (source line 2207) and lines 2215-2217 are
// missing from this listing — confirm against the real header.
2208 public:
2209 explicit Unflatten(Stream stream, int axis, Shape shape)
2210 : UnaryPrimitive(stream), axis_(axis), shape_(std::move(shape)) {}
2211
2212 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2213 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2214
2218 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
2219 bool is_equivalent(const Primitive& other) const override;
2220
  // Static helper taking an input array, an axis and a shape and returning a
  // Shape; its definition is not in this file.
2221 static Shape output_shape(const array& input, int axis, const Shape& shape);
  // Returns the pair (axis, copy of shape).
2222 auto state() const {
2223 return std::make_pair(axis_, shape_);
2224 }
2225
2226 private:
2227 int axis_;
2228 Shape shape_;
  // Private evaluation helper; presumably shared by eval_cpu and eval_gpu —
  // the definition is not in this file, verify.
2229 void eval(const std::vector<array>& inputs, array& out);
2230};
2231
// View: a UnaryPrimitive subclass that stores the Dtype given at
// construction. Unlike most classes in this listing, print() is only declared
// here and defined elsewhere.
// NOTE(review): this listing skips source line 2240 — confirm against the
// real header what is declared there.
2232class View : public UnaryPrimitive {
2233 public:
2234 explicit View(Stream stream, Dtype dtype)
2235 : UnaryPrimitive(stream), dtype_(dtype) {}
2236
2237 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2238 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2239
2241 void print(std::ostream& os) override;
2242 bool is_equivalent(const Primitive& other) const override;
  // Returns the stored dtype.
2243 auto state() const {
2244 return dtype_;
2245 }
2246
2247 private:
2248 Dtype dtype_;
2249};
2250
// Transpose: per the visible constructor, a UnaryPrimitive that stores a
// copy of the given list of axes.
// NOTE(review): the class header (source line 2251) and lines 2259-2261 are
// missing from this listing — confirm against the real header.
2252 public:
2253 explicit Transpose(Stream stream, const std::vector<int>& axes)
2254 : UnaryPrimitive(stream), axes_(axes) {}
2255
2256 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2257 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2258
2262 bool is_equivalent(const Primitive& other) const override;
2263 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
  // Returns a copy of the stored axes.
2264 std::vector<int> state() const {
2265 return axes_;
2266 };
2267
2268 private:
2269 std::vector<int> axes_;
2270
  // Private evaluation helper; presumably shared by eval_cpu and eval_gpu —
  // the definition is not in this file, verify.
2271 void eval(const std::vector<array>& inputs, array& out);
2272};
2273
// QRF derives directly from Primitive, so its eval overrides take a vector
// of output arrays rather than a single one.
// NOTE(review): this listing skips source lines 2277 (presumably the
// constructor) and 2284 — confirm against the real header.
2274/* QR Factorization primitive. */
2275class QRF : public Primitive {
2276 public:
2278
2279 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2280 override;
2281 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2282 override;
2283
2285};
2286
// SVD derives directly from Primitive, so its eval overrides take a vector
// of output arrays rather than a single one.
// NOTE(review): this listing skips source lines 2290 (presumably the
// constructor) and 2297-2298 — confirm against the real header.
2287/* SVD primitive. */
2288class SVD : public Primitive {
2289 public:
2291
2292 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2293 override;
2294 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2295 override;
2296
2299};
2300
// Inverse is a UnaryPrimitive that stores two flags, `tri` and `upper`.
// Presumably these select a triangular input and which triangle is used — the
// eval implementations are not in this file, verify.
// NOTE(review): this listing skips source lines 2310-2311 — confirm against
// the real header.
2301/* Matrix inversion primitive. */
2302class Inverse : public UnaryPrimitive {
2303 public:
2304 explicit Inverse(Stream stream, bool tri, bool upper)
2305 : UnaryPrimitive(stream), tri_(tri), upper_(upper) {}
2306
2307 void eval_cpu(const std::vector<array>& inputs, array& output) override;
2308 void eval_gpu(const std::vector<array>& inputs, array& output) override;
2309
  // Returns the pair (tri, upper).
2312 auto state() const {
2313 return std::make_pair(tri_, upper_);
2314 }
2315
2316 private:
2317 bool tri_;
2318 bool upper_;
2319};
2320
// Cholesky: a UnaryPrimitive subclass that stores a single `upper` flag.
// Presumably it selects the upper rather than lower triangular factor — the
// eval implementations are not in this file, verify.
// NOTE(review): this listing skips source lines 2332-2333 — confirm against
// the real header.
2321class Cholesky : public UnaryPrimitive {
2322 public:
2323 explicit Cholesky(Stream stream, bool upper)
2324 : UnaryPrimitive(stream), upper_(upper) {}
2325
2326 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2327 void eval_gpu(const std::vector<array>& inputs, array& out) override;
  // Returns the stored `upper` flag.
2328 auto state() const {
2329 return upper_;
2330 }
2331
2334
2335 private:
2336 bool upper_;
2337};
2338
// Eigh: derives directly from Primitive, so its eval overrides take a vector
// of output arrays. Stores a string `uplo` (taken by value and moved in) and
// a `compute_eigenvectors` flag.
// NOTE(review): the accepted values of `uplo` are not visible in this file.
// This listing also skips source lines 2350-2351 — confirm against the real
// header.
2339class Eigh : public Primitive {
2340 public:
2341 explicit Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
2342 : Primitive(stream),
2343 uplo_(std::move(uplo)),
2344 compute_eigenvectors_(compute_eigenvectors) {}
2345 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2346 override;
2347 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2348 override;
2349
2352
2353 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
2354
2355 bool is_equivalent(const Primitive& other) const override;
  // Returns the pair (copy of uplo, compute_eigenvectors).
2356 auto state() const {
2357 return std::make_pair(uplo_, compute_eigenvectors_);
2358 }
2359
2360 private:
2361 std::string uplo_;
2362 bool compute_eigenvectors_;
2363};
2364
// LUF derives directly from Primitive, so its eval overrides take a vector
// of output arrays rather than a single one.
// NOTE(review): this listing skips source lines 2368 (presumably the
// constructor) and 2374 — confirm against the real header.
2365/* LU Factorization primitive. */
2366class LUF : public Primitive {
2367 public:
2369 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2370 override;
2371 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2372 override;
2373
2375};
2376
2377} // namespace mlx::core
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Abs(Stream stream)
Definition primitives.h:156
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Add(Stream stream)
Definition primitives.h:170
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::pair< float, float > state() const
Definition primitives.h:195
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
AddMM(Stream stream, float alpha, float beta)
Definition primitives.h:184
Arange(Stream stream, double start, double stop, double step)
Definition primitives.h:206
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::tuple< double, double, double > state() const
Definition primitives.h:215
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCos(Stream stream)
Definition primitives.h:227
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCosh(Stream stream)
Definition primitives.h:241
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcSin(Stream stream)
Definition primitives.h:255
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcSinh(Stream stream)
Definition primitives.h:269
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTan2(Stream stream)
Definition primitives.h:297
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcTan(Stream stream)
Definition primitives.h:283
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTanh(Stream stream)
Definition primitives.h:311
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::pair< int, int > state() const
Definition primitives.h:336
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArgPartition(Stream stream, int kth, int axis)
Definition primitives.h:325
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
ReduceType
Definition primitives.h:347
@ ArgMin
Definition primitives.h:348
@ ArgMax
Definition primitives.h:349
ArgReduce(Stream stream, ReduceType reduce_type, int axis)
Definition primitives.h:352
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::pair< ReduceType, int > state() const
Definition primitives.h:363
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
ArgSort(Stream stream, int axis)
Definition primitives.h:374
int state() const
Definition primitives.h:384
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:427
AsStrided(Stream stream, Shape shape, Strides strides, size_t offset)
Definition primitives.h:415
void eval_gpu(const std::vector< array > &inputs, array &out) override
AsType(Stream stream, Dtype dtype)
Definition primitives.h:394
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Dtype state() const
Definition primitives.h:405
void eval_cpu(const std::vector< array > &inputs, array &out) override
BitwiseBinary(Stream stream, Op op)
Definition primitives.h:443
void eval_cpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Op
Definition primitives.h:441
@ RightShift
Definition primitives.h:441
@ Or
Definition primitives.h:441
@ LeftShift
Definition primitives.h:441
@ And
Definition primitives.h:441
@ Xor
Definition primitives.h:441
auto state() const
Definition primitives.h:454
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
BitwiseInvert(Stream stream)
Definition primitives.h:464
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
auto state() const
Definition primitives.h:491
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
BlockMaskedMM(Stream stream, int block_size)
Definition primitives.h:477
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
BroadcastAxes(Stream stream, std::vector< int > ignore_axes={})
Definition primitives.h:518
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:532
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const std::vector< array > &inputs, const std::vector< int > &ignore_axes)
Broadcast(Stream stream, const Shape &shape)
Definition primitives.h:543
static Shape output_shape(const std::vector< array > &inputs)
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< int > state() const
Definition primitives.h:555
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Ceil(Stream stream)
Definition primitives.h:567
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2328
Cholesky(Stream stream, bool upper)
Definition primitives.h:2323
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void print(std::ostream &os) override
Print the primitive.
Compiled(Stream stream, std::vector< array > inputs, std::vector< array > outputs, std::vector< array > tape, std::unordered_set< uintptr_t > constant_ids)
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
std::string lib_name() const
Definition primitives.h:608
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:634
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Concatenate(Stream stream, int axis)
Definition primitives.h:623
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Conjugate(Stream stream)
Definition primitives.h:644
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Contiguous(Stream stream, bool allow_col_major)
Definition primitives.h:657
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
Convolution(Stream stream, const std::vector< int > &kernel_strides, const std::vector< int > &padding, const std::vector< int > &kernel_dilation, const std::vector< int > &input_dilation, const int groups=1, const bool flip=false)
Definition primitives.h:676
auto state() const
Definition primitives.h:703
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
Copy(Stream stream)
Definition primitives.h:724
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cos(Stream stream)
Definition primitives.h:741
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Cosh(Stream stream)
Definition primitives.h:755
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes) override
The primitive must know how to vectorize itself across the given axes.
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
The Jacobian-vector product.
CustomTransforms(Stream stream, int num_outputs, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> vjp, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< int > &)> jvp, std::function< std::pair< std::vector< array >, std::vector< int > >(const std::vector< array > &, const std::vector< int > &)> vmap)
Definition primitives.h:769
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Depends(Stream stream)
Definition primitives.h:821
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:867
DivMod(Stream stream)
Definition primitives.h:856
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Divide(Stream stream)
Definition primitives.h:842
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
DynamicSlice(Stream stream, std::vector< int > axes, Shape slice_size)
Definition primitives.h:1971
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1984
auto state() const
Definition primitives.h:2006
DynamicSliceUpdate(Stream stream, std::vector< int > axes)
Definition primitives.h:1995
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
auto state() const
Definition primitives.h:2356
Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
Definition primitives.h:2341
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:913
Equal(Stream stream, bool equal_nan=false)
Definition primitives.h:902
auto state() const
Definition primitives.h:920
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Erf(Stream stream)
Definition primitives.h:930
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ErfInv(Stream stream)
Definition primitives.h:944
void eval_cpu(const std::vector< array > &inputs, array &out) override
Exp(Stream stream)
Definition primitives.h:958
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const array &input, const std::vector< int > &axes)
auto state() const
Definition primitives.h:999
void eval_gpu(const std::vector< array > &inputs, array &out) override
ExpandDims(Stream stream, std::vector< int > axes)
Definition primitives.h:985
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Expm1(Stream stream)
Definition primitives.h:972
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
FFT(Stream stream, const std::vector< size_t > &axes, bool inverse, bool real)
Definition primitives.h:1010
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1025
static Shape output_shape(const array &input, int start_axis, int end_axis)
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
Flatten(Stream stream, int start_axis, int end_axis)
Definition primitives.h:1037
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1050
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Floor(Stream stream)
Definition primitives.h:1062
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Full(Stream stream)
Definition primitives.h:1076
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
GatherAxis(Stream stream, int axis)
Definition primitives.h:1113
auto state() const
Definition primitives.h:1124
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::pair< std::vector< int >, std::vector< int > > state() const
Definition primitives.h:1102
Gather(Stream stream, std::vector< int > axes, Shape slice_sizes)
Definition primitives.h:1089
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
GatherMM(Stream stream)
Definition primitives.h:501
auto state() const
Definition primitives.h:1588
GatherQMM(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1575
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
GreaterEqual(Stream stream)
Definition primitives.h:1148
void eval_gpu(const std::vector< array > &inputs, array &out) override
Greater(Stream stream)
Definition primitives.h:1134
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
Hadamard(Stream stream, float scale)
Definition primitives.h:1162
auto state() const
Definition primitives.h:1174
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Imag(Stream stream)
Definition primitives.h:1184
void eval_gpu(const std::vector< array > &inputs, array &output) override
Inverse(Stream stream, bool tri, bool upper)
Definition primitives.h:2304
auto state() const
Definition primitives.h:2312
void eval_cpu(const std::vector< array > &inputs, array &output) override
LUF(Stream stream)
Definition primitives.h:2368
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
LessEqual(Stream stream)
Definition primitives.h:1212
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Less(Stream stream)
Definition primitives.h:1198
void eval_gpu(const std::vector< array > &inputs, array &out) override
Load(Stream stream, std::shared_ptr< io::Reader > reader, size_t offset, bool swap_endianness=false)
Definition primitives.h:1226
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Log1p(Stream stream)
Definition primitives.h:1294
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogAddExp(Stream stream)
Definition primitives.h:1349
Base
Definition primitives.h:1257
@ ten
Definition primitives.h:1257
@ two
Definition primitives.h:1257
@ e
Definition primitives.h:1257
Log(Stream stream, Base base)
Definition primitives.h:1259
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1274
Base state() const
Definition primitives.h:1270
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalAnd(Stream stream)
Definition primitives.h:1321
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalNot(Stream stream)
Definition primitives.h:1307
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
LogicalOr(Stream stream)
Definition primitives.h:1335
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Matmul(Stream stream)
Definition primitives.h:1363
Maximum(Stream stream)
Definition primitives.h:1377
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Minimum(Stream stream)
Definition primitives.h:1391
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Multiply(Stream stream)
Definition primitives.h:1405
void eval_gpu(const std::vector< array > &inputs, array &out) override
Negative(Stream stream)
Definition primitives.h:1419
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
NotEqual(Stream stream)
Definition primitives.h:1433
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:1463
NumberOfElements(Stream stream, std::vector< int > axes, bool inverted, Dtype dtype)
Definition primitives.h:1447
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::tuple< std::vector< int >, bool, Dtype > state() const
Definition primitives.h:1466
auto state() const
Definition primitives.h:1497
Pad(Stream stream, const std::vector< int > &axes, const Shape &low_pad_size, const Shape &high_pad_size)
Definition primitives.h:1480
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Partition(Stream stream, int kth, int axis)
Definition primitives.h:1509
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
auto state() const
Definition primitives.h:1520
void eval_cpu(const std::vector< array > &inputs, array &out) override
Power(Stream stream)
Definition primitives.h:1531
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:48
virtual void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
The vector-Jacobian product.
virtual ~Primitive()=default
Primitive(const Primitive &other)=delete
Primitive(Primitive &&other)=delete
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
Primitive & operator=(Primitive &&other)=delete
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
Primitive & operator=(const Primitive &other)=delete
const Device & device()
The device the primitive will run on.
Definition primitives.h:53
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
The Jacobian-vector product.
virtual std::vector< Shape > output_shapes(const std::vector< array > &inputs)
Get the output shapes of the primitive.
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes)
The primitive must know how to vectorize itself across the given axes.
virtual void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
virtual void print(std::ostream &os)=0
Print the primitive.
Primitive(Stream stream)
Definition primitives.h:50
QRF(Stream stream)
Definition primitives.h:2277
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
QuantizedMatmul(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1545
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1563
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::pair< std::vector< int >, int > state() const
Definition primitives.h:1609
RandomBits(Stream stream, const Shape &shape, int width)
Definition primitives.h:1600
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Real(Stream stream)
Definition primitives.h:1620
Reduce(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1658
ReduceType
Definition primitives.h:1656
@ Min
Definition primitives.h:1656
@ Or
Definition primitives.h:1656
@ Max
Definition primitives.h:1656
@ And
Definition primitives.h:1656
@ Sum
Definition primitives.h:1656
@ Prod
Definition primitives.h:1656
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1677
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::pair< ReduceType, std::vector< int > > state() const
Definition primitives.h:1700
Remainder(Stream stream)
Definition primitives.h:888
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const array &input, Shape shape)
void eval_gpu(const std::vector< array > &inputs, array &out) override
Reshape(Stream stream, const Shape &shape)
Definition primitives.h:1634
std::vector< int > state() const
Definition primitives.h:1644
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Round(Stream stream)
Definition primitives.h:1711
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
SVD(Stream stream)
Definition primitives.h:2290
void eval_cpu(const std::vector< array > &inputs, array &out) override
ReduceType
Definition primitives.h:1725
@ Prod
Definition primitives.h:1725
@ Min
Definition primitives.h:1725
@ Max
Definition primitives.h:1725
@ Sum
Definition primitives.h:1725
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
auto state() const
Definition primitives.h:1763
Scan(Stream stream, ReduceType reduce_type, int axis, bool reverse, bool inclusive)
Definition primitives.h:1727
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1745
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::pair< ReduceType, int > state() const
Definition primitives.h:1845
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1832
void eval_gpu(const std::vector< array > &inputs, array &out) override
ScatterAxis(Stream stream, ReduceType reduce_type, int axis)
Definition primitives.h:1823
ReduceType
Definition primitives.h:1821
@ Sum
Definition primitives.h:1821
@ None
Definition primitives.h:1821
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::pair< ReduceType, std::vector< int > > state() const
Definition primitives.h:1810
ReduceType
Definition primitives.h:1776
@ Sum
Definition primitives.h:1776
@ Max
Definition primitives.h:1776
@ Prod
Definition primitives.h:1776
@ None
Definition primitives.h:1776
@ Min
Definition primitives.h:1776
void eval_cpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1790
void eval_gpu(const std::vector< array > &inputs, array &out) override
Scatter(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1778
void eval_gpu(const std::vector< array > &inputs, array &out) override
Select(Stream stream)
Definition primitives.h:874
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sigmoid(Stream stream)
Definition primitives.h:1856
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Sign(Stream stream)
Definition primitives.h:1870
Sin(Stream stream)
Definition primitives.h:1884
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sinh(Stream stream)
Definition primitives.h:1898
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1929
Slice(Stream stream, const Shape &start_indices, const Shape &end_indices, const Shape &strides)
Definition primitives.h:1912
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
SliceUpdate(Stream stream, const Shape &start_indices, const Shape &end_indices, const Shape &strides)
Definition primitives.h:1941
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1959
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Softmax(Stream stream, bool precise)
Definition primitives.h:2016
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2028
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2049
Sort(Stream stream, int axis)
Definition primitives.h:2038
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::pair< std::vector< int >, int > state() const
Definition primitives.h:2071
Split(Stream stream, const Shape &indices, int axis)
Definition primitives.h:2059
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
auto state() const
Definition primitives.h:2108
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sqrt(Stream stream, bool recip=false)
Definition primitives.h:2098
void eval_gpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:2112
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Square(Stream stream)
Definition primitives.h:2084
Squeeze(Stream stream, std::vector< int > axes)
Definition primitives.h:2156
auto state() const
Definition primitives.h:2170
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const array &input, const std::vector< int > &axes)
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
StopGradient(Stream stream)
Definition primitives.h:2126
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Subtract(Stream stream)
Definition primitives.h:2142
Tan(Stream stream)
Definition primitives.h:2181
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Tanh(Stream stream)
Definition primitives.h:2195
void eval_cpu(const std::vector< array > &inputs, array &out) override
Transpose(Stream stream, const std::vector< int > &axes)
Definition primitives.h:2253
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< int > state() const
Definition primitives.h:2264
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:126
UnaryPrimitive & operator=(const UnaryPrimitive &other)=delete
UnaryPrimitive(Stream stream)
An abstract base class for a primitive with a single output.
Definition primitives.h:131
virtual void eval_gpu(const std::vector< array > &inputs, array &output)=0
UnaryPrimitive(UnaryPrimitive &&other)=delete
virtual void eval_cpu(const std::vector< array > &inputs, array &output)=0
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:141
UnaryPrimitive(const UnaryPrimitive &other)=delete
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition primitives.h:136
UnaryPrimitive & operator=(UnaryPrimitive &&other)=delete
virtual ~UnaryPrimitive()=default
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Unflatten(Stream stream, int axis, Shape shape)
Definition primitives.h:2209
static Shape output_shape(const array &input, int axis, const Shape &shape)
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2222
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2243
void print(std::ostream &os) override
Print the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
View(Stream stream, Dtype dtype)
Definition primitives.h:2234
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition array.h:24
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
array tri(int n, int m, int k, Dtype type, StreamOrDevice s={})
array transpose(const array &a, std::vector< int > axes, StreamOrDevice s={})
Permutes the dimensions according to the given axes.
array real(const array &a, StreamOrDevice s={})
Definition allocator.h:7
std::vector< ShapeElem > Shape
Definition array.h:21
Stream new_stream(Device d)
Make a new stream on the given device.
std::vector< int64_t > Strides
Definition array.h:22
void eval(std::vector< array > outputs)
#define DEFINE_DEFAULT_IS_EQUIVALENT()
Definition primitives.h:34
#define DEFINE_PRINT(PRIMITIVE)
Definition primitives.h:29
#define DEFINE_INPUT_OUTPUT_SHAPE()
Definition primitives.h:39
#define DEFINE_GRADS()
Definition primitives.h:17
#define DEFINE_VMAP()
Definition primitives.h:12
Definition device.h:7
static constexpr DeviceType gpu
Definition device.h:14
static constexpr DeviceType cpu
Definition device.h:13
Definition dtype.h:13
Definition stream.h:9
Device device
Definition stream.h:11