MLX
 
Loading...
Searching...
No Matches
primitives.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <unordered_set>
6
7#include "mlx/array.h"
8#include "mlx/device.h"
9#include "mlx/io/load.h"
10#include "mlx/stream.h"
11
// Declares (does not define) the `vmap` override inside a primitive's class
// body. The signature matches the virtual `vmap` on the base class; the
// definition has to be provided out of line.
12#define DEFINE_VMAP() \
13 virtual std::pair<std::vector<array>, std::vector<int>> vmap( \
14 const std::vector<array>& inputs, const std::vector<int>& axes) \
15 override;
16
// Declares (does not define) both differentiation overrides, `jvp` and `vjp`,
// inside a primitive's class body. Primitives that only need one of the two
// declare it by hand instead of using this macro.
17#define DEFINE_GRADS() \
18 std::vector<array> jvp( \
19 const std::vector<array>& primals, \
20 const std::vector<array>& tangents, \
21 const std::vector<int>& argnums) override; \
22 \
23 std::vector<array> vjp( \
24 const std::vector<array>& primals, \
25 const std::vector<array>& cotangents, \
26 const std::vector<int>& argnums, \
27 const std::vector<array>& outputs) override;
28
// Defines an inline `print` override that writes the macro argument,
// stringified with `#`, to the stream. `DEFINE_PRINT(Abs)` therefore prints
// the literal text "Abs".
29#define DEFINE_PRINT(PRIMITIVE) \
30 void print(std::ostream& os) override { \
31 os << #PRIMITIVE; \
32 }
33
// Defines an inline `is_equivalent` override that always returns true and
// never inspects `other`.
// NOTE(review): this is only sound if the caller has already established that
// `other` is the same primitive type — confirm at the call site.
34#define DEFINE_DEFAULT_IS_EQUIVALENT() \
35 bool is_equivalent(const Primitive& other) const override { \
36 return true; \
37 }
38
// Defines an inline `output_shapes` override that reports exactly one output
// whose shape is the shape of the first input.
// `inputs[0]` is read without a size check, so `inputs` must be non-empty.
39#define DEFINE_INPUT_OUTPUT_SHAPE() \
40 std::vector<Shape> output_shapes(const std::vector<array>& inputs) \
41 override { \
42 return {inputs[0].shape()}; \
43 }
44
45namespace mlx::core {
46
47// Abstract base class
48class Primitive {
49 public:
50 explicit Primitive(Stream stream) : stream_(stream) {}
51
53 const Device& device() {
54 return stream().device;
55 }
56
58 const Stream& stream() {
59 return stream_;
60 }
61
69 virtual void eval_cpu(
70 const std::vector<array>& inputs,
71 std::vector<array>& outputs) = 0;
72 virtual void eval_gpu(
73 const std::vector<array>& inputs,
74 std::vector<array>& outputs) = 0;
75
79 virtual std::vector<array> jvp(
80 const std::vector<array>& primals,
81 const std::vector<array>& tangents,
82 const std::vector<int>& argnums);
83
87 virtual std::vector<array> vjp(
88 const std::vector<array>& primals,
89 const std::vector<array>& cotangents,
90 const std::vector<int>& argnums,
91 const std::vector<array>& outputs);
92
99 virtual std::pair<std::vector<array>, std::vector<int>> vmap(
100 const std::vector<array>& inputs,
101 const std::vector<int>& axes);
102
104 virtual void print(std::ostream& os) = 0;
105
107 virtual bool is_equivalent(const Primitive& other) const {
108 return false;
109 }
110
113 virtual std::vector<Shape> output_shapes(const std::vector<array>& inputs);
114
115 virtual ~Primitive() = default;
116 Primitive(const Primitive& other) = delete;
117 Primitive(Primitive&& other) = delete;
118 Primitive& operator=(const Primitive& other) = delete;
119 Primitive& operator=(Primitive&& other) = delete;
120
121 private:
122 // Every primitive stores the stream it should run in
123 Stream stream_;
124};
125
// Base class for primitives that write a single output array.
// Subclasses implement the `array& output` overloads of eval_cpu / eval_gpu;
// the vector-of-outputs overloads required by Primitive are implemented here
// by forwarding to them with `outputs[0]`.
126class UnaryPrimitive : public Primitive {
130 public:
132
 // Single-output evaluation hooks every subclass must implement.
133 virtual void eval_cpu(const std::vector<array>& inputs, array& output) = 0;
134 virtual void eval_gpu(const std::vector<array>& inputs, array& output) = 0;
135
 // Adapters from the multi-output interface: only the first element of
 // `outputs` is used, and it is accessed without a size check.
136 inline void eval_cpu(
137 const std::vector<array>& inputs,
138 std::vector<array>& outputs) override {
139 eval_cpu(inputs, outputs[0]);
140 }
141 inline void eval_gpu(
142 const std::vector<array>& inputs,
143 std::vector<array>& outputs) override {
144 eval_gpu(inputs, outputs[0]);
145 }
146
 // Copying is disabled, as in the Primitive base class.
147 virtual ~UnaryPrimitive() = default;
148 UnaryPrimitive(const UnaryPrimitive& other) = delete;
150 UnaryPrimitive& operator=(const UnaryPrimitive& other) = delete;
152};
153
154class Abs : public UnaryPrimitive {
155 public:
157
158 void eval_cpu(const std::vector<array>& inputs, array& out) override;
159 void eval_gpu(const std::vector<array>& inputs, array& out) override;
160
166};
167
168class Add : public UnaryPrimitive {
169 public:
171
172 void eval_cpu(const std::vector<array>& inputs, array& out) override;
173 void eval_gpu(const std::vector<array>& inputs, array& out) override;
174
180};
181
// Single-output primitive parameterized by two immutable scalars, alpha and
// beta.
// NOTE(review): the name suggests a fused alpha * (a @ b) + beta * c; confirm
// against the out-of-line eval implementations.
182class AddMM : public UnaryPrimitive {
183 public:
184 explicit AddMM(Stream stream, float alpha, float beta)
185 : UnaryPrimitive(stream), alpha_(alpha), beta_(beta) {}
186
187 void eval_cpu(const std::vector<array>& inputs, array& out) override;
188 void eval_gpu(const std::vector<array>& inputs, array& out) override;
189
193
194 bool is_equivalent(const Primitive& other) const override;
 // Returns the construction parameters as (alpha, beta).
195 std::pair<float, float> state() const {
196 return {alpha_, beta_};
197 };
198
199 private:
200 const float alpha_;
201 const float beta_;
202};
203
// Single-output primitive configured by a (start, stop, step) triple of
// doubles. Equivalence and output-shape logic are defined out of line.
204class Arange : public UnaryPrimitive {
205 public:
206 explicit Arange(Stream stream, double start, double stop, double step)
207 : UnaryPrimitive(stream), start_(start), stop_(stop), step_(step) {}
208
209 void eval_cpu(const std::vector<array>& inputs, array& out) override;
210 void eval_gpu(const std::vector<array>& inputs, array& out) override;
211
213 bool is_equivalent(const Primitive& other) const override;
214 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
 // Returns the construction parameters as (start, stop, step).
215 std::tuple<double, double, double> state() const {
216 return {start_, stop_, step_};
217 };
218
219 private:
220 double start_;
221 double stop_;
222 double step_;
223};
224
225class ArcCos : public UnaryPrimitive {
226 public:
228
229 void eval_cpu(const std::vector<array>& inputs, array& out) override;
230 void eval_gpu(const std::vector<array>& inputs, array& out) override;
231
237};
238
239class ArcCosh : public UnaryPrimitive {
240 public:
242
243 void eval_cpu(const std::vector<array>& inputs, array& out) override;
244 void eval_gpu(const std::vector<array>& inputs, array& out) override;
245
251};
252
253class ArcSin : public UnaryPrimitive {
254 public:
256
257 void eval_cpu(const std::vector<array>& inputs, array& out) override;
258 void eval_gpu(const std::vector<array>& inputs, array& out) override;
259
265};
266
267class ArcSinh : public UnaryPrimitive {
268 public:
270
271 void eval_cpu(const std::vector<array>& inputs, array& out) override;
272 void eval_gpu(const std::vector<array>& inputs, array& out) override;
273
279};
280
281class ArcTan : public UnaryPrimitive {
282 public:
284
285 void eval_cpu(const std::vector<array>& inputs, array& out) override;
286 void eval_gpu(const std::vector<array>& inputs, array& out) override;
287
293};
294
295class ArcTan2 : public UnaryPrimitive {
296 public:
298
299 void eval_cpu(const std::vector<array>& inputs, array& out) override;
300 void eval_gpu(const std::vector<array>& inputs, array& out) override;
301
307};
308
309class ArcTanh : public UnaryPrimitive {
310 public:
312
313 void eval_cpu(const std::vector<array>& inputs, array& out) override;
314 void eval_gpu(const std::vector<array>& inputs, array& out) override;
315
321};
322
324 public:
325 explicit ArgPartition(Stream stream, int kth, int axis)
326 : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
327
328 void eval_cpu(const std::vector<array>& inputs, array& out) override;
329 void eval_gpu(const std::vector<array>& inputs, array& out) override;
330
335 bool is_equivalent(const Primitive& other) const override;
336 std::pair<int, int> state() const {
337 return {kth_, axis_};
338 };
339
340 private:
341 int kth_;
342 int axis_;
343};
344
// Single-output primitive configured by a reduction kind and an axis.
// `ReduceType` is declared on a line not visible in this listing.
345class ArgReduce : public UnaryPrimitive {
346 public:
351
352 explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)
353 : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {}
354
355 void eval_cpu(const std::vector<array>& inputs, array& out) override;
356 void eval_gpu(const std::vector<array>& inputs, array& out) override;
357
361 bool is_equivalent(const Primitive& other) const override;
362 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
 // Returns the construction parameters as (reduce_type, axis).
363 std::pair<ReduceType, int> state() const {
364 return {reduce_type_, axis_};
365 };
366
367 private:
368 ReduceType reduce_type_;
369 int axis_;
370};
371
// Single-output primitive configured by one axis.
372class ArgSort : public UnaryPrimitive {
373 public:
374 explicit ArgSort(Stream stream, int axis)
375 : UnaryPrimitive(stream), axis_(axis) {}
376
377 void eval_cpu(const std::vector<array>& inputs, array& out) override;
378 void eval_gpu(const std::vector<array>& inputs, array& out) override;
379
383 bool is_equivalent(const Primitive& other) const override;
 // Returns the axis given at construction.
384 int state() const {
385 return axis_;
386 };
387
388 private:
389 int axis_;
390};
391
// Single-output primitive configured by a target Dtype.
392class AsType : public UnaryPrimitive {
393 public:
394 explicit AsType(Stream stream, Dtype dtype)
395 : UnaryPrimitive(stream), dtype_(dtype) {}
396
397 void eval_cpu(const std::vector<array>& inputs, array& out) override;
398 void eval_gpu(const std::vector<array>& inputs, array& out) override;
399
404 bool is_equivalent(const Primitive& other) const override;
 // Returns the dtype given at construction.
405 Dtype state() const {
406 return dtype_;
407 };
408
409 private:
410 Dtype dtype_;
411};
412
// Single-output primitive configured by a shape, a set of strides and an
// offset. `shape` and `strides` are taken by value and moved into the members.
413class AsStrided : public UnaryPrimitive {
414 public:
415 explicit AsStrided(Stream stream, Shape shape, Strides strides, size_t offset)
417 shape_(std::move(shape)),
418 strides_(std::move(strides)),
419 offset_(offset) {}
420
421 void eval_cpu(const std::vector<array>& inputs, array& out) override;
422 void eval_gpu(const std::vector<array>& inputs, array& out) override;
423
426 bool is_equivalent(const Primitive& other) const override;
 // Returns copies of the construction parameters as a
 // (shape, strides, offset) tuple.
427 auto state() const {
428 return std::make_tuple(shape_, strides_, offset_);
429 }
430
431 private:
432 Shape shape_;
433 Strides strides_;
434 size_t offset_;
435
 // Private helper defined out of line; presumably shared by eval_cpu and
 // eval_gpu — confirm in the implementation file.
436 void eval(const std::vector<array>& inputs, array& out);
437};
438
440 public:
442
444 : UnaryPrimitive(stream), op_(op) {}
445
446 void eval_cpu(const std::vector<array>& inputs, array& out) override;
447 void eval_gpu(const std::vector<array>& inputs, array& out) override;
448
451 bool is_equivalent(const Primitive& other) const override;
452 void print(std::ostream& os) override;
454 auto state() const {
455 return op_;
456 }
457
458 private:
459 Op op_;
460};
461
463 public:
464 explicit BlockMaskedMM(Stream stream, int block_size)
465 : UnaryPrimitive(stream), block_size_(block_size) {}
466
467 void eval_cpu(const std::vector<array>& inputs, array& out) override;
468 void eval_gpu(const std::vector<array>& inputs, array& out) override;
469
470 std::vector<array> vjp(
471 const std::vector<array>& primals,
472 const std::vector<array>& cotangents,
473 const std::vector<int>& argnums,
474 const std::vector<array>& outputs) override;
475
477 bool is_equivalent(const Primitive& other) const override;
478 auto state() const {
479 return block_size_;
480 }
481
482 private:
483 int block_size_;
484};
485
486class GatherMM : public UnaryPrimitive {
487 public:
489
490 void eval_cpu(const std::vector<array>& inputs, array& out) override;
491 void eval_gpu(const std::vector<array>& inputs, array& out) override;
492
493 std::vector<array> vjp(
494 const std::vector<array>& primals,
495 const std::vector<array>& cotangents,
496 const std::vector<int>& argnums,
497 const std::vector<array>& outputs) override;
498
501};
502
504 public:
505 explicit BroadcastAxes(Stream stream, std::vector<int> ignore_axes = {})
506 : UnaryPrimitive(stream), ignore_axes_(std::move(ignore_axes)) {}
507
508 void eval_cpu(const std::vector<array>& inputs, array& out) override;
509 void eval_gpu(const std::vector<array>& inputs, array& out) override;
510
514 bool is_equivalent(const Primitive& other) const override;
516 const std::vector<array>& inputs,
517 const std::vector<int>& ignore_axes);
518 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
519 auto state() const {
520 return ignore_axes_;
521 }
522
523 private:
524 void eval(const std::vector<array>& inputs, array& out);
525 std::vector<int> ignore_axes_;
526};
527
// Single-output primitive configured by a target shape, which is copied into
// the primitive at construction.
528class Broadcast : public UnaryPrimitive {
529 public:
530 explicit Broadcast(Stream stream, const Shape& shape)
531 : UnaryPrimitive(stream), shape_(shape) {}
532
533 void eval_cpu(const std::vector<array>& inputs, array& out) override;
534 void eval_gpu(const std::vector<array>& inputs, array& out) override;
535
 // Static shape helper defined out of line; it takes only the inputs, not
 // the stored shape.
539 static Shape output_shape(const std::vector<array>& inputs);
540 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
541 bool is_equivalent(const Primitive& other) const override;
 // Returns a copy of the stored shape.
 // NOTE(review): the return type is spelled std::vector<int> while the
 // member is a Shape; this relies on Shape being (convertible to)
 // std::vector<int>.
542 std::vector<int> state() const {
543 return shape_;
544 };
545
546 private:
547 Shape shape_;
548
 // Private helper defined out of line; presumably shared by eval_cpu and
 // eval_gpu — confirm in the implementation file.
549 void eval(const std::vector<array>& inputs, array& out);
550};
551
552class Ceil : public UnaryPrimitive {
553 public:
555
556 void eval_cpu(const std::vector<array>& inputs, array& out) override;
557 void eval_gpu(const std::vector<array>& inputs, array& out) override;
558
564};
565
// Multi-output primitive (derives from Primitive directly) that holds a
// captured set of input, output and tape arrays plus a set of constant ids.
// The constructor and all virtual overrides are defined out of line.
566class Compiled : public Primitive {
567 public:
568 /*
569 * The inputs, outputs and tape are either tracers or constants.
570 * - The tape should not contain the inputs, but it should contain the
571 * outputs.
572 * - The tape should also have only one array per primitive for multi-output
573 * primitives.
574 * - The constant_ids contains ids of arrays in the input list that are safe
575 * to treat as scalar constants.
576 */
577 explicit Compiled(
579 std::vector<array> inputs,
580 std::vector<array> outputs,
581 std::vector<array> tape,
582 std::unordered_set<uintptr_t> constant_ids);
583
584 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
585 override;
586 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
587 override;
588
591 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
592 void print(std::ostream& os) override;
593 bool is_equivalent(const Primitive& other) const override;
594
 // Returns a copy of the stored kernel library name.
595 std::string lib_name() const {
596 return kernel_lib_;
597 }
598
599 private:
 // Captured graph description; const, so fixed after construction.
600 const std::vector<array> inputs_;
601 const std::vector<array> outputs_;
602 const std::vector<array> tape_;
603 const std::unordered_set<uintptr_t> constant_ids_;
604
 // Not const, unlike the members above; where it is assigned is not
 // visible in this header.
605 std::string kernel_lib_;
606};
607
609 public:
610 explicit Concatenate(Stream stream, int axis)
611 : UnaryPrimitive(stream), axis_(axis) {}
612
613 void eval_cpu(const std::vector<array>& inputs, array& out) override;
614 void eval_gpu(const std::vector<array>& inputs, array& out) override;
615
619 bool is_equivalent(const Primitive& other) const override;
620 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
621 auto state() const {
622 return axis_;
623 }
624
625 private:
626 int axis_;
627};
628
629class Conjugate : public UnaryPrimitive {
630 public:
632
633 void eval_cpu(const std::vector<array>& inputs, array& out) override;
634 void eval_gpu(const std::vector<array>& inputs, array& out) override;
635
640};
641
643 public:
644 explicit Contiguous(Stream stream, bool allow_col_major)
645 : UnaryPrimitive(stream), allow_col_major_(allow_col_major) {}
646
647 void eval_cpu(const std::vector<array>& inputs, array& out) override;
648 void eval_gpu(const std::vector<array>& inputs, array& out) override;
649
654
655 bool is_equivalent(const Primitive& other) const override;
656
657 private:
658 bool allow_col_major_;
659};
660
662 public:
663 explicit Convolution(
665 const std::vector<int>& kernel_strides,
666 const std::vector<int>& padding,
667 const std::vector<int>& kernel_dilation,
668 const std::vector<int>& input_dilation,
669 const int groups = 1,
670 const bool flip = false)
672 padding_(padding),
673 kernel_strides_(kernel_strides),
674 kernel_dilation_(kernel_dilation),
675 input_dilation_(input_dilation),
676 groups_(groups),
677 flip_(flip) {}
678
679 void eval_cpu(const std::vector<array>& inputs, array& out) override;
680 void eval_gpu(const std::vector<array>& inputs, array& out) override;
681
682 std::vector<array> vjp(
683 const std::vector<array>& primals,
684 const std::vector<array>& cotangents,
685 const std::vector<int>& argnums,
686 const std::vector<array>& outputs) override;
687
689 bool is_equivalent(const Primitive& other) const override;
690 auto state() const {
691 return std::make_tuple(
692 padding_,
693 kernel_strides_,
694 kernel_dilation_,
695 input_dilation_,
696 groups_,
697 flip_);
698 }
699
700 private:
701 std::vector<int> padding_;
702 std::vector<int> kernel_strides_;
703 std::vector<int> kernel_dilation_;
704 std::vector<int> input_dilation_;
705 int groups_;
706 bool flip_;
707};
708
709class Copy : public UnaryPrimitive {
710 public:
712
713 void eval_cpu(const std::vector<array>& inputs, array& out) override;
714 void eval_gpu(const std::vector<array>& inputs, array& out) override;
715
721
722 private:
723 void eval(const std::vector<array>& inputs, array& out);
724};
725
726class Cos : public UnaryPrimitive {
727 public:
729
730 void eval_cpu(const std::vector<array>& inputs, array& out) override;
731 void eval_gpu(const std::vector<array>& inputs, array& out) override;
732
738};
739
740class Cosh : public UnaryPrimitive {
741 public:
743
744 void eval_cpu(const std::vector<array>& inputs, array& out) override;
745 void eval_gpu(const std::vector<array>& inputs, array& out) override;
746
752};
753
755 public:
758 int num_outputs,
759 std::function<std::vector<array>(
760 const std::vector<array>&,
761 const std::vector<array>&,
762 const std::vector<array>&)> vjp,
763 std::function<std::vector<array>(
764 const std::vector<array>&,
765 const std::vector<array>&,
766 const std::vector<int>&)> jvp,
767 std::function<std::pair<std::vector<array>, std::vector<int>>(
768 const std::vector<array>&,
769 const std::vector<int>&)> vmap)
770 : Primitive(stream),
771 num_outputs_(num_outputs),
772 vjp_fun_(std::move(vjp)),
773 jvp_fun_(std::move(jvp)),
774 vmap_fun_(std::move(vmap)) {}
775
776 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
777 override;
778 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
779 override;
780
784
785 private:
786 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
787
788 int num_outputs_;
789
790 std::function<std::vector<array>(
791 const std::vector<array>&,
792 const std::vector<array>&,
793 const std::vector<array>&)>
794 vjp_fun_;
795 std::function<std::vector<array>(
796 const std::vector<array>&,
797 const std::vector<array>&,
798 const std::vector<int>&)>
799 jvp_fun_;
800 std::function<std::pair<std::vector<array>, std::vector<int>>(
801 const std::vector<array>&,
802 const std::vector<int>&)>
803 vmap_fun_;
804};
805
806class Depends : public Primitive {
807 public:
809
810 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
811 override;
812 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
813 override;
814
815 std::vector<array> vjp(
816 const std::vector<array>& primals,
817 const std::vector<array>& cotan,
818 const std::vector<int>& argnums,
819 const std::vector<array>& outputs) override;
820
822
823 private:
824 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
825};
826
827class Divide : public UnaryPrimitive {
828 public:
830
831 void eval_cpu(const std::vector<array>& inputs, array& out) override;
832 void eval_gpu(const std::vector<array>& inputs, array& out) override;
833
839};
840
// Multi-output primitive (derives from Primitive directly, not
// UnaryPrimitive) that reports two outputs.
841class DivMod : public Primitive {
842 public:
844
845 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
846 override;
847 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
848 override;
849
 // Both outputs take the shape of the first input. `inputs[0]` is read
 // without a size check, so `inputs` must be non-empty.
854 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
855 return std::vector{inputs[0].shape(), inputs[0].shape()};
856 }
857};
858
859class Select : public UnaryPrimitive {
860 public:
862
863 void eval_cpu(const std::vector<array>& inputs, array& out) override;
864 void eval_gpu(const std::vector<array>& inputs, array& out) override;
865
871};
872
873class Remainder : public UnaryPrimitive {
874 public:
876
877 void eval_cpu(const std::vector<array>& inputs, array& out) override;
878 void eval_gpu(const std::vector<array>& inputs, array& out) override;
879
885};
886
// Single-output comparison primitive with an `equal_nan` flag that defaults
// to false.
// NOTE(review): presumably the flag makes NaNs compare equal to each other —
// confirm in the eval implementations.
887class Equal : public UnaryPrimitive {
888 public:
889 explicit Equal(Stream stream, bool equal_nan = false)
890 : UnaryPrimitive(stream), equal_nan_(equal_nan) {}
891
892 void eval_cpu(const std::vector<array>& inputs, array& out) override;
893 void eval_gpu(const std::vector<array>& inputs, array& out) override;
894
899
 // Prints "NaNEqual" when the flag is set and "Equal" otherwise, so the two
 // variants are distinguishable in printed output.
900 void print(std::ostream& os) override {
901 if (equal_nan_) {
902 os << "NaNEqual";
903 } else {
904 os << "Equal";
905 }
906 }
 // Returns the equal_nan flag given at construction.
907 auto state() const {
908 return equal_nan_;
909 };
910
911 private:
912 bool equal_nan_;
913};
914
915class Erf : public UnaryPrimitive {
916 public:
918
919 void eval_cpu(const std::vector<array>& inputs, array& out) override;
920 void eval_gpu(const std::vector<array>& inputs, array& out) override;
921
927};
928
929class ErfInv : public UnaryPrimitive {
930 public:
932
933 void eval_cpu(const std::vector<array>& inputs, array& out) override;
934 void eval_gpu(const std::vector<array>& inputs, array& out) override;
935
941};
942
943class Exp : public UnaryPrimitive {
944 public:
946
947 void eval_cpu(const std::vector<array>& inputs, array& out) override;
948 void eval_gpu(const std::vector<array>& inputs, array& out) override;
949
955};
956
957class Expm1 : public UnaryPrimitive {
958 public:
960
961 void eval_cpu(const std::vector<array>& inputs, array& out) override;
962 void eval_gpu(const std::vector<array>& inputs, array& out) override;
963
968};
969
971 public:
972 explicit ExpandDims(Stream stream, std::vector<int> axes)
973 : UnaryPrimitive(stream), axes_(std::move(axes)) {}
974
975 void eval_cpu(const std::vector<array>& inputs, array& out) override;
976 void eval_gpu(const std::vector<array>& inputs, array& out) override;
977
981
982 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
983 bool is_equivalent(const Primitive& other) const override;
984
985 static Shape output_shape(const array& input, const std::vector<int>& axes);
986 auto state() const {
987 return axes_;
988 }
989
990 private:
991 void eval(const std::vector<array>& inputs, array& out);
992 std::vector<int> axes_;
993};
994
// Single-output primitive configured by a list of axes and two flags,
// `inverse` and `real`. The axes vector is copied at construction.
995class FFT : public UnaryPrimitive {
996 public:
997 explicit FFT(
999 const std::vector<size_t>& axes,
1000 bool inverse,
1001 bool real)
1002 : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {}
1003
1004 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1005 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1006
1010
1011 bool is_equivalent(const Primitive& other) const override;
 // Returns copies of the construction parameters as an
 // (axes, inverse, real) tuple.
1012 auto state() const {
1013 return std::make_tuple(axes_, inverse_, real_);
1014 }
1015
1016 private:
1017 std::vector<size_t> axes_;
1018 bool inverse_;
1019 bool real_;
1020};
1021
// Single-output primitive configured by a start axis and an end axis.
1022class Flatten : public UnaryPrimitive {
1023 public:
1024 explicit Flatten(Stream stream, int start_axis, int end_axis)
1025 : UnaryPrimitive(stream), start_axis_(start_axis), end_axis_(end_axis) {}
1026
1027 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1028 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1029
1033 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1034 bool is_equivalent(const Primitive& other) const override;
1035
 // Static shape helper (defined out of line) usable without an instance.
1036 static Shape output_shape(const array& input, int start_axis, int end_axis);
 // Returns the construction parameters as (start_axis, end_axis).
1037 auto state() const {
1038 return std::make_pair(start_axis_, end_axis_);
1039 }
1040
1041 private:
1042 int start_axis_;
1043 int end_axis_;
 // Private helper defined out of line; presumably shared by eval_cpu and
 // eval_gpu — confirm in the implementation file.
1044 void eval(const std::vector<array>& inputs, array& out);
1045};
1046
1047class Floor : public UnaryPrimitive {
1048 public:
1050
1051 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1052 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1053
1059};
1060
1061class Full : public UnaryPrimitive {
1062 public:
1064
1065 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1066 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1067
1072};
1073
// Single-output primitive configured by a list of axes and a Shape of slice
// sizes. Both are taken by value and moved into the members.
1074class Gather : public UnaryPrimitive {
1075 public:
1076 explicit Gather(Stream stream, std::vector<int> axes, Shape slice_sizes)
1078 axes_(std::move(axes)),
1079 slice_sizes_(std::move(slice_sizes)) {}
1080
1081 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1082 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1083
1087 bool is_equivalent(const Primitive& other) const override;
1088 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
 // Returns copies of the construction parameters as (axes, slice_sizes).
 // NOTE(review): the second element is spelled std::vector<int> while the
 // member is a Shape; this relies on Shape being (convertible to)
 // std::vector<int>.
1089 std::pair<std::vector<int>, std::vector<int>> state() const {
1090 return {axes_, slice_sizes_};
1091 }
1092
1093 private:
1094 std::vector<int> axes_;
1095 Shape slice_sizes_;
1096};
1097
1099 public:
1100 explicit GatherAxis(Stream stream, int axis)
1101 : UnaryPrimitive(stream), axis_(axis) {}
1102
1103 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1104 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1105
1109 bool is_equivalent(const Primitive& other) const override;
1110 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1111 auto state() const {
1112 return axis_;
1113 }
1114
1115 private:
1116 int axis_;
1117};
1118
1119class Greater : public UnaryPrimitive {
1120 public:
1122
1123 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1124 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1125
1131};
1132
1134 public:
1136
1137 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1138 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1139
1145};
1146
// Single-output primitive configured by a float scale factor.
1147class Hadamard : public UnaryPrimitive {
1148 public:
1149 explicit Hadamard(Stream stream, float scale)
1150 : UnaryPrimitive(stream), scale_(scale) {}
1151
1152 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1153 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1154
1159
1160 bool is_equivalent(const Primitive& other) const override;
 // Returns the scale given at construction.
1161 auto state() const {
1162 return scale_;
1163 }
1164
1165 private:
1166 float scale_;
1167};
1168
1169class Imag : public UnaryPrimitive {
1170 public:
1172
1173 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1174 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1175
1181};
1182
1183class Less : public UnaryPrimitive {
1184 public:
1186
1187 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1188 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1189
1195};
1196
1198 public:
1200
1201 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1202 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1203
1209};
1210
// Single-output primitive that holds a shared io::Reader, an offset and an
// endianness-swap flag (default false).
// NOTE(review): presumably the offset is a position within the reader's
// data — confirm in the eval implementations.
1211class Load : public UnaryPrimitive {
1212 public:
1213 explicit Load(
1214 Stream stream,
1215 std::shared_ptr<io::Reader> reader,
1216 size_t offset,
1217 bool swap_endianness = false)
1219 reader_(std::move(reader)),
1220 offset_(offset),
1221 swap_endianness_(swap_endianness) {
 // The return value is discarded: for GPU streams the call is made only for
 // its side effect of creating the function-local static stream inside
 // io_stream() at construction time.
1222 if (stream.device == Device::gpu) {
1223 io_stream();
1224 }
1225 }
1226
1227 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1228 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1229
1231
1232 private:
 // Returns a CPU stream that is created once, on first call, and then shared
 // by every Load instance (function-local static).
1233 Stream& io_stream() {
1234 static Stream io_stream = new_stream(Device::cpu);
1235 return io_stream;
1236 };
1237 std::shared_ptr<io::Reader> reader_;
1238 size_t offset_;
1239 bool swap_endianness_;
1240};
1241
// Single-output logarithm primitive; the base is selected at construction.
1242class Log : public UnaryPrimitive {
1243 public:
 // Supported bases. Unscoped enum, so the enumerators are reachable as
 // Log::two, Log::ten and Log::e.
1244 enum Base { two, ten, e };
1245
1246 explicit Log(Stream stream, Base base)
1247 : UnaryPrimitive(stream), base_(base) {}
1248
1249 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1250 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1251
1256
 // Returns the base given at construction.
1257 Base state() const {
1258 return base_;
1259 };
1260
 // Prints "Log", "Log2" or "Log10" for bases e, two and ten respectively.
 // There is no default case, so nothing is written for a value outside the
 // three enumerators.
1261 void print(std::ostream& os) override {
1262 switch (base_) {
1263 case e:
1264 os << "Log";
1265 break;
1266 case two:
1267 os << "Log2";
1268 break;
1269 case ten:
1270 os << "Log10";
1271 break;
1272 }
1273 }
1274
1275 private:
1276 Base base_;
1277};
1278
1279class Log1p : public UnaryPrimitive {
1280 public:
1282
1283 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1284 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1285
1290};
1291
1293 public:
1295
1296 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1297 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1298
1304};
1305
1307 public:
1309
1310 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1311 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1312
1318};
1319
1321 public:
1323
1324 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1325 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1326
1332};
1333
1335 public:
1337
1338 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1339 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1340
1346};
1347
1348class Matmul : public UnaryPrimitive {
1349 public:
1351
1352 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1353 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1354
1359 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1360};
1361
1362class Maximum : public UnaryPrimitive {
1363 public:
1365
1366 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1367 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1368
1374};
1375
1376class Minimum : public UnaryPrimitive {
1377 public:
1379
1380 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1381 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1382
1388};
1389
1390class Multiply : public UnaryPrimitive {
1391 public:
1393
1394 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1395 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1396
1402};
1403
1404class Negative : public UnaryPrimitive {
1405 public:
1407
1408 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1409 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1410
1416};
1417
1418class NotEqual : public UnaryPrimitive {
1419 public:
1421
1422 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1423 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1424
1430};
1431
1433 public:
1435 Stream stream,
1436 std::vector<int> axes,
1437 bool inverted,
1438 Dtype dtype)
1440 axes_(std::move(axes)),
1441 inverted_(inverted),
1442 dtype_(dtype) {}
1443
1444 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1445 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1446
1449 bool is_equivalent(const Primitive& other) const override;
1450 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
1451 return {{}};
1452 }
1453 std::tuple<std::vector<int>, bool, Dtype> state() const {
1454 return {axes_, inverted_, dtype_};
1455 }
1456
1457 private:
1458 std::vector<int> axes_;
1459 bool inverted_;
1460 Dtype dtype_;
1461
1462 void eval(const std::vector<array>& inputs, array& out);
1463};
1464
// Single-output primitive configured by a list of axes plus a low and a high
// pad size. All three are copied at construction.
// NOTE(review): presumably low/high hold one entry per listed axis — confirm
// against the caller.
1465class Pad : public UnaryPrimitive {
1466 public:
1467 explicit Pad(
1468 Stream stream,
1469 const std::vector<int>& axes,
1470 const Shape& low_pad_size,
1471 const Shape& high_pad_size)
1473 axes_(axes),
1474 low_pad_size_(low_pad_size),
1475 high_pad_size_(high_pad_size) {}
1476
1477 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1478 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1479
1483 bool is_equivalent(const Primitive& other) const override;
 // Returns copies of the construction parameters as an
 // (axes, low_pad_size, high_pad_size) tuple.
1484 auto state() const {
1485 return std::make_tuple(axes_, low_pad_size_, high_pad_size_);
1486 }
1487
1488 private:
1489 std::vector<int> axes_;
1490 Shape low_pad_size_;
1491 Shape high_pad_size_;
1492};
1493
// Partition (name taken from the constructor; the class declaration line 1494
// is absent from this listing, as are lines 1502-1505).
// Stores two integers, `kth` and `axis`, passed to the constructor, which also
// forwards `stream` to UnaryPrimitive.
1495 public:
1496 explicit Partition(Stream stream, int kth, int axis)
1497 : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
1498
1499 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1500 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1501
1506 bool is_equivalent(const Primitive& other) const override;
// Returns the pair (kth, axis).
// NOTE(review): the ';' after the closing brace is a redundant empty
// declaration.
1507 auto state() const {
1508 return std::make_pair(kth_, axis_);
1509 };
1510
1511 private:
1512 int kth_;
1513 int axis_;
1514};
1515
// Power: primitive derived from UnaryPrimitive. The visible members are the
// CPU and GPU evaluation overrides, each taking the input arrays and one `out`
// array.
// NOTE(review): the fused listing numbers skip 1518 and 1523-1527, so the
// constructor and any further members are not shown in this listing.
1516class Power : public UnaryPrimitive {
1517 public:
1519
1520 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1521 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1522
1528};
1529
// NOTE(review): the class declaration line (1530), the constructor name line
// (1532), the start of the initializer list (1537) and lines 1545-1547 are
// absent from this listing, so the class name and base class are not visible.
// What is visible: the constructor stores `group_size`, `bits` and a
// `transpose` flag; state() returns the same three values.
1531 public:
1533 Stream stream,
1534 int group_size,
1535 int bits,
1536 bool transpose)
1538 group_size_(group_size),
1539 bits_(bits),
1540 transpose_(transpose) {}
1541
1542 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1543 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1544
1548 bool is_equivalent(const Primitive& other) const override;
1549 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
// Returns the tuple (group_size, bits, transpose).
1550 auto state() const {
1551 return std::make_tuple(group_size_, bits_, transpose_);
1552 }
1553
1554 private:
1555 int group_size_;
1556 int bits_;
1557 bool transpose_;
1558};
1559
// GatherQMM (name taken from the constructor; the class declaration line 1560
// is absent from this listing, as are lines 1563 and 1571-1573).
// Stores `group_size`, `bits` and a `transpose` flag from the constructor.
1561 public:
1562 explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
1564 group_size_(group_size),
1565 bits_(bits),
1566 transpose_(transpose) {}
1567
1568 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1569 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1570
1574 bool is_equivalent(const Primitive& other) const override;
// Returns the tuple (group_size, bits, transpose).
1575 auto state() const {
1576 return std::make_tuple(group_size_, bits_, transpose_);
1577 }
1578
1579 private:
1580 int group_size_;
1581 int bits_;
1582 bool transpose_;
1583};
1584
// RandomBits (name taken from the constructor; the class declaration line 1585
// is absent from this listing, as are lines 1593-1594).
// Stores a Shape and an integer `width` from the constructor, which also
// forwards `stream` to UnaryPrimitive.
1586 public:
1587 explicit RandomBits(Stream stream, const Shape& shape, int width)
1588 : UnaryPrimitive(stream), shape_(shape), width_(width) {}
1589
1590 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1591 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1592
1595 bool is_equivalent(const Primitive& other) const override;
// Returns the pair (shape, width).
// NOTE(review): the return type spells std::vector<int> while the member is a
// Shape; this relies on Shape being (convertible to) std::vector<int> --
// confirm against the Shape alias. The ';' after the closing brace is a
// redundant empty declaration.
1596 std::pair<std::vector<int>, int> state() const {
1597 return {shape_, width_};
1598 };
1599
1600 private:
1601 Shape shape_;
1602 int width_;
1603};
1604
// Real: primitive derived from UnaryPrimitive. The visible members are the CPU
// and GPU evaluation overrides, each taking the input arrays and one `out`
// array.
// NOTE(review): the fused listing numbers skip 1607 and 1612-1616, so the
// constructor and any further members are not shown in this listing.
1605class Real : public UnaryPrimitive {
1606 public:
1608
1609 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1610 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1611
1617};
1618
// Reshape: primitive derived from UnaryPrimitive that stores a Shape copied
// from the constructor argument.
// NOTE(review): listing lines 1627-1629 are absent here.
1619class Reshape : public UnaryPrimitive {
1620 public:
1621 explicit Reshape(Stream stream, const Shape& shape)
1622 : UnaryPrimitive(stream), shape_(shape) {}
1623
1624 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1625 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1626
1630 bool is_equivalent(const Primitive& other) const override;
// Returns a copy of the stored shape.
// NOTE(review): declared as std::vector<int> although the member is a Shape;
// relies on Shape being (convertible to) std::vector<int> -- confirm. The ';'
// after the closing brace is a redundant empty declaration.
1631 std::vector<int> state() const {
1632 return shape_;
1633 };
// Static helper taking an input array and a shape (by value) and returning a
// Shape; its definition is not part of this header listing.
1634 static Shape output_shape(const array& input, Shape shape);
1635 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1636
1637 private:
1638 Shape shape_;
1639};
1640
// Reduce: primitive derived from UnaryPrimitive, parameterized by a ReduceType
// and a list of axes (copied from the constructor argument).
// NOTE(review): listing line 1654 is absent here.
1641class Reduce : public UnaryPrimitive {
1642 public:
// Kinds of reduction this primitive can be constructed with.
1643 enum ReduceType { And, Or, Sum, Prod, Min, Max };
1644
1645 explicit Reduce(
1646 Stream stream,
1647 ReduceType reduce_type,
1648 const std::vector<int>& axes)
1649 : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1650
1651 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1652 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1653
1655
1656 std::vector<array> vjp(
1657 const std::vector<array>& primals,
1658 const std::vector<array>& cotangents,
1659 const std::vector<int>& argnums,
1660 const std::vector<array>& outputs) override;
1661
1662 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1663
// Writes the enumerator name of reduce_type_ ("And", "Or", "Sum", "Prod",
// "Min" or "Max") to `os`. Every enumerator has a case; there is no default.
1664 void print(std::ostream& os) override {
1665 switch (reduce_type_) {
1666 case And:
1667 os << "And";
1668 break;
1669 case Or:
1670 os << "Or";
1671 break;
1672 case Sum:
1673 os << "Sum";
1674 break;
1675 case Prod:
1676 os << "Prod";
1677 break;
1678 case Min:
1679 os << "Min";
1680 break;
1681 case Max:
1682 os << "Max";
1683 break;
1684 }
1685 }
1686 bool is_equivalent(const Primitive& other) const override;
// Returns the pair (reduce_type, axes); the axes vector is copied.
1687 std::pair<ReduceType, std::vector<int>> state() const {
1688 return {reduce_type_, axes_};
1689 };
1690
1691 private:
1692 ReduceType reduce_type_;
1693 std::vector<int> axes_;
1694};
1695
// Round: primitive derived from UnaryPrimitive. The visible members are the
// CPU and GPU evaluation overrides, each taking the input arrays and one `out`
// array.
// NOTE(review): the fused listing numbers skip 1698 and 1703-1707, so the
// constructor and any further members are not shown in this listing.
1696class Round : public UnaryPrimitive {
1697 public:
1699
1700 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1701 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1702
1708};
1709
// Scan: primitive derived from UnaryPrimitive, parameterized by a reduce type,
// an axis and two flags (`reverse`, `inclusive`), all stored from the
// constructor arguments.
// NOTE(review): listing lines 1712 (where the ReduceType used below is
// presumably declared), 1720 (start of the initializer list) and 1729-1730 are
// absent here.
1710class Scan : public UnaryPrimitive {
1711 public:
1713
1714 explicit Scan(
1715 Stream stream,
1716 ReduceType reduce_type,
1717 int axis,
1718 bool reverse,
1719 bool inclusive)
1721 reduce_type_(reduce_type),
1722 axis_(axis),
1723 reverse_(reverse),
1724 inclusive_(inclusive) {}
1725
1726 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1727 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1728
1731
// Writes "Cum" followed by "Sum", "Prod", "Min" or "Max" according to
// reduce_type_. There is no default case, so any other enumerator value would
// print only "Cum".
1732 void print(std::ostream& os) override {
1733 os << "Cum";
1734 switch (reduce_type_) {
1735 case Sum:
1736 os << "Sum";
1737 break;
1738 case Prod:
1739 os << "Prod";
1740 break;
1741 case Min:
1742 os << "Min";
1743 break;
1744 case Max:
1745 os << "Max";
1746 break;
1747 }
1748 }
1749 bool is_equivalent(const Primitive& other) const override;
// Returns the tuple (reduce_type, axis, reverse, inclusive).
1750 auto state() const {
1751 return std::make_tuple(reduce_type_, axis_, reverse_, inclusive_);
1752 }
1753
1754 private:
1755 ReduceType reduce_type_;
1756 int axis_;
1757 bool reverse_;
1758 bool inclusive_;
1759};
1760
// Scatter: primitive derived from UnaryPrimitive, parameterized by a reduce
// type and a list of axes (copied from the constructor argument).
// NOTE(review): listing lines 1763 (where the ReduceType used below is
// presumably declared) and 1774-1775 are absent here.
1761class Scatter : public UnaryPrimitive {
1762 public:
1764
1765 explicit Scatter(
1766 Stream stream,
1767 ReduceType reduce_type,
1768 const std::vector<int>& axes)
1769 : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1770
1771 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1772 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1773
1776
// Writes "Scatter", then " Sum", " Prod", " Min" or " Max" (with a leading
// space) for the matching reduce type; nothing is appended for None.
1777 void print(std::ostream& os) override {
1778 os << "Scatter";
1779 switch (reduce_type_) {
1780 case Sum:
1781 os << " Sum";
1782 break;
1783 case Prod:
1784 os << " Prod";
1785 break;
1786 case Min:
1787 os << " Min";
1788 break;
1789 case Max:
1790 os << " Max";
1791 break;
1792 case None:
1793 break;
1794 }
1795 }
1796 bool is_equivalent(const Primitive& other) const override;
// Returns the pair (reduce_type, axes); the axes vector is copied.
1797 std::pair<ReduceType, std::vector<int>> state() const {
1798 return {reduce_type_, axes_};
1799 };
1800
1801 private:
1802 ReduceType reduce_type_;
1803 std::vector<int> axes_;
1804};
1805
// ScatterAxis (name taken from the constructor; the class declaration line
// 1806 is absent from this listing, as are lines 1808 -- presumably the
// ReduceType declaration -- and 1816-1817).
// Stores a reduce type and a single integer axis from the constructor.
1807 public:
1809
1810 explicit ScatterAxis(Stream stream, ReduceType reduce_type, int axis)
1811 : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {}
1812
1813 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1814 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1815
1818
// Writes "ScatterAxis", followed by " Sum" when the reduce type is Sum;
// nothing is appended for None.
1819 void print(std::ostream& os) override {
1820 os << "ScatterAxis";
1821 switch (reduce_type_) {
1822 case Sum:
1823 os << " Sum";
1824 break;
1825 case None:
1826 break;
1827 }
1828 }
1829
1830 bool is_equivalent(const Primitive& other) const override;
1831 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
// Returns the pair (reduce_type, axis).
1832 std::pair<ReduceType, int> state() const {
1833 return {reduce_type_, axis_};
1834 }
1835
1836 private:
1837 ReduceType reduce_type_;
1838 int axis_;
1839};
1840
// Sigmoid: primitive derived from UnaryPrimitive. The visible members are the
// CPU and GPU evaluation overrides, each taking the input arrays and one `out`
// array.
// NOTE(review): the fused listing numbers skip 1843 and 1848-1852, so the
// constructor and any further members are not shown in this listing.
1841class Sigmoid : public UnaryPrimitive {
1842 public:
1844
1845 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1846 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1847
1853};
1854
// Sign: primitive derived from UnaryPrimitive. The visible members are the CPU
// and GPU evaluation overrides, each taking the input arrays and one `out`
// array.
// NOTE(review): the fused listing numbers skip 1857 and 1862-1866, so the
// constructor and any further members are not shown in this listing.
1855class Sign : public UnaryPrimitive {
1856 public:
1858
1859 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1860 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1861
1867};
1868
// Sin: primitive derived from UnaryPrimitive. The visible members are the CPU
// and GPU evaluation overrides, each taking the input arrays and one `out`
// array.
// NOTE(review): the fused listing numbers skip 1871 and 1876-1880, so the
// constructor and any further members are not shown in this listing.
1869class Sin : public UnaryPrimitive {
1870 public:
1872
1873 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1874 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1875
1881};
1882
// Sinh: primitive derived from UnaryPrimitive. The visible members are the CPU
// and GPU evaluation overrides, each taking the input arrays and one `out`
// array.
// NOTE(review): the fused listing numbers skip 1885 and 1890-1894, so the
// constructor and any further members are not shown in this listing.
1883class Sinh : public UnaryPrimitive {
1884 public:
1886
1887 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1888 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1889
1895};
1896
// Slice: primitive derived from UnaryPrimitive that stores start indices, end
// indices and strides (each a Shape), copied from the constructor arguments.
// NOTE(review): listing lines 1904 (start of the initializer list) and
// 1912-1914 are absent here.
1897class Slice : public UnaryPrimitive {
1898 public:
1899 explicit Slice(
1900 Stream stream,
1901 const Shape& start_indices,
1902 const Shape& end_indices,
1903 const Shape& strides)
1905 start_indices_(start_indices),
1906 end_indices_(end_indices),
1907 strides_(strides) {}
1908
1909 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1910 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1911
1915 bool is_equivalent(const Primitive& other) const override;
// Returns copies of the stored parameters as (start_indices, end_indices, strides).
1916 auto state() const {
1917 return std::make_tuple(start_indices_, end_indices_, strides_);
1918 }
1919
1920 private:
1921 Shape start_indices_;
1922 Shape end_indices_;
1923 Shape strides_;
1924};
1925
// SliceUpdate (name taken from the constructor; the class declaration line
// 1926 is absent from this listing, as are lines 1933, 1941-1943 and 1945).
// Stores start indices, end indices and strides (each a Shape), copied from
// the constructor arguments -- the same parameter set as Slice.
1927 public:
1928 explicit SliceUpdate(
1929 Stream stream,
1930 const Shape& start_indices,
1931 const Shape& end_indices,
1932 const Shape& strides)
1934 start_indices_(start_indices),
1935 end_indices_(end_indices),
1936 strides_(strides) {}
1937
1938 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1939 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1940
1944 bool is_equivalent(const Primitive& other) const override;
// Returns copies of the stored parameters as (start_indices, end_indices, strides).
1946 auto state() const {
1947 return std::make_tuple(start_indices_, end_indices_, strides_);
1948 }
1949
1950 private:
1951 Shape start_indices_;
1952 Shape end_indices_;
1953 Shape strides_;
1954};
1955
// DynamicSlice (name taken from the constructor; the class declaration line
// 1956 is absent from this listing, as are lines 1959 and 1966-1968).
// Takes the axes and the slice size by value and moves them into the members.
1957 public:
1958 explicit DynamicSlice(Stream stream, std::vector<int> axes, Shape slice_size)
1960 axes_(std::move(axes)),
1961 slice_size_(std::move(slice_size)) {}
1962
1963 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1964 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1965
1969 bool is_equivalent(const Primitive& other) const override;
1970 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
// Returns copies of the stored parameters as the pair (axes, slice_size).
1971 auto state() const {
1972 return std::make_pair(axes_, slice_size_);
1973 }
1974
1975 private:
1976 std::vector<int> axes_;
1977 Shape slice_size_;
1978};
1979
// DynamicSliceUpdate (name taken from the constructor; the class declaration
// line 1980 is absent from this listing, as are lines 1988-1990 and 1992).
// Takes the axes by value and moves them into the single data member.
1981 public:
1982 explicit DynamicSliceUpdate(Stream stream, std::vector<int> axes)
1983 : UnaryPrimitive(stream), axes_(std::move(axes)) {}
1984
1985 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1986 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1987
1991 bool is_equivalent(const Primitive& other) const override;
// Returns a copy of the stored axes (`auto` deduces a by-value return).
1993 auto state() const {
1994 return axes_;
1995 }
1996
1997 private:
1998 std::vector<int> axes_;
1999};
2000
// Softmax: primitive derived from UnaryPrimitive that stores one boolean,
// `precise`, from the constructor.
// NOTE(review): listing lines 2009-2012 are absent here; what `precise`
// selects is not visible in this header listing.
2001class Softmax : public UnaryPrimitive {
2002 public:
2003 explicit Softmax(Stream stream, bool precise)
2004 : UnaryPrimitive(stream), precise_(precise) {}
2005
2006 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2007 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2008
2013
2014 bool is_equivalent(const Primitive& other) const override;
// Returns the stored `precise` flag. The ';' after the closing brace is a
// redundant empty declaration.
2015 auto state() const {
2016 return precise_;
2017 };
2018
2019 private:
2020 bool precise_;
2021};
2022
// Sort: primitive derived from UnaryPrimitive that stores one integer axis
// from the constructor.
// NOTE(review): listing lines 2031-2034 are absent here.
2023class Sort : public UnaryPrimitive {
2024 public:
2025 explicit Sort(Stream stream, int axis)
2026 : UnaryPrimitive(stream), axis_(axis) {}
2027
2028 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2029 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2030
2035 bool is_equivalent(const Primitive& other) const override;
// Returns the stored axis.
2036 auto state() const {
2037 return axis_;
2038 }
2039
2040 private:
2041 int axis_;
2042};
2043
// Split: derives directly from Primitive (not UnaryPrimitive), so its eval
// overrides take a vector of output arrays. Stores split indices (a Shape) and
// an integer axis, copied from the constructor arguments.
// NOTE(review): listing lines 2054-2056 are absent here.
2044class Split : public Primitive {
2045 public:
2046 explicit Split(Stream stream, const Shape& indices, int axis)
2047 : Primitive(stream), indices_(indices), axis_(axis) {}
2048
2049 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2050 override;
2051 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2052 override;
2053
2057 bool is_equivalent(const Primitive& other) const override;
// Returns the pair (indices, axis).
// NOTE(review): the return type spells std::vector<int> while the member is a
// Shape; relies on Shape being (convertible to) std::vector<int> -- confirm.
// The ';' after the closing brace is a redundant empty declaration.
2058 std::pair<std::vector<int>, int> state() const {
2059 return {indices_, axis_};
2060 };
2061
2062 private:
// Private helper with the same parameters as eval_cpu/eval_gpu; presumably the
// shared implementation behind both -- confirm in the source file.
2063 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
2064
2065 Shape indices_;
2066 int axis_;
2067};
2068
// Square: primitive derived from UnaryPrimitive. The visible members are the
// CPU and GPU evaluation overrides, each taking the input arrays and one `out`
// array.
// NOTE(review): the fused listing numbers skip 2071 and 2076-2080, so the
// constructor and any further members are not shown in this listing.
2069class Square : public UnaryPrimitive {
2070 public:
2072
2073 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2074 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2075
2081};
2082
// Sqrt: primitive derived from UnaryPrimitive with a `recip` flag (default
// false). The flag changes the printed name to "Rsqrt", so presumably it
// selects the reciprocal square root -- confirm in the implementation.
// NOTE(review): listing lines 2091-2093 are absent here.
2083class Sqrt : public UnaryPrimitive {
2084 public:
2085 explicit Sqrt(Stream stream, bool recip = false)
2086 : UnaryPrimitive(stream), recip_(recip) {}
2087
2088 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2089 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2090
2094 bool is_equivalent(const Primitive& other) const override;
// Returns the stored `recip` flag.
2095 auto state() const {
2096 return recip_;
2097 }
2098
// Writes "Rsqrt" when recip_ is set, otherwise "Sqrt".
2099 void print(std::ostream& os) override {
2100 if (recip_) {
2101 os << "Rsqrt";
2102 } else {
2103 os << "Sqrt";
2104 }
2105 }
2106
2107 private:
2108 bool recip_;
2109};
2110
// NOTE(review): the class declaration line (2111), line 2113 (presumably the
// constructor) and lines 2118-2121 are absent from this listing, so the class
// name and base class are not visible here.
// What is visible: CPU/GPU evaluation overrides with a single `out` array and
// a private eval() helper with the same parameters.
2112 public:
2114
2115 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2116 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2117
2122
2123 private:
// Presumably the shared implementation behind eval_cpu and eval_gpu -- confirm
// in the source file.
2124 void eval(const std::vector<array>& inputs, array& out);
2125};
2126
// Subtract: primitive derived from UnaryPrimitive. The visible members are the
// CPU and GPU evaluation overrides, each taking the input arrays and one `out`
// array.
// NOTE(review): the fused listing numbers skip 2129 and 2134-2138, so the
// constructor and any further members are not shown in this listing.
2127class Subtract : public UnaryPrimitive {
2128 public:
2130
2131 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2132 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2133
2139};
2140
// Squeeze: primitive derived from UnaryPrimitive that stores a list of axes,
// taken by value and moved into the member.
// NOTE(review): listing lines 2149-2151 are absent here.
2141class Squeeze : public UnaryPrimitive {
2142 public:
2143 explicit Squeeze(Stream stream, std::vector<int> axes)
2144 : UnaryPrimitive(stream), axes_(std::move(axes)) {}
2145
2146 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2147 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2148
2152
2153 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
2154 bool is_equivalent(const Primitive& other) const override;
2155
// Static helper taking an input array and a list of axes and returning a
// Shape; its definition is not part of this header listing.
2156 static Shape output_shape(const array& input, const std::vector<int>& axes);
// Returns a copy of the stored axes. The ';' after the closing brace is a
// redundant empty declaration.
2157 auto state() const {
2158 return axes_;
2159 };
2160
2161 private:
// Presumably the shared implementation behind eval_cpu and eval_gpu -- confirm
// in the source file.
2162 void eval(const std::vector<array>& inputs, array& out);
2163 std::vector<int> axes_;
2164};
2165
// Tan: primitive derived from UnaryPrimitive. The visible members are the CPU
// and GPU evaluation overrides, each taking the input arrays and one `out`
// array.
// NOTE(review): the fused listing numbers skip 2168 and 2173-2177, so the
// constructor and any further members are not shown in this listing.
2166class Tan : public UnaryPrimitive {
2167 public:
2169
2170 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2171 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2172
2178};
2179
// Tanh: primitive derived from UnaryPrimitive. The visible members are the CPU
// and GPU evaluation overrides, each taking the input arrays and one `out`
// array.
// NOTE(review): the fused listing numbers skip 2182 and 2187-2191, so the
// constructor and any further members are not shown in this listing.
2180class Tanh : public UnaryPrimitive {
2181 public:
2183
2184 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2185 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2186
2192};
2193
// Unflatten (name taken from the constructor; the class declaration line 2194
// is absent from this listing, as are lines 2202-2204).
// Stores an integer axis and a Shape; the shape is taken by value and moved
// into the member.
2195 public:
2196 explicit Unflatten(Stream stream, int axis, Shape shape)
2197 : UnaryPrimitive(stream), axis_(axis), shape_(std::move(shape)) {}
2198
2199 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2200 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2201
2205 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
2206 bool is_equivalent(const Primitive& other) const override;
2207
// Static helper taking an input array, an axis and a shape and returning a
// Shape; its definition is not part of this header listing.
2208 static Shape output_shape(const array& input, int axis, const Shape& shape);
// Returns the pair (axis, shape); the shape is copied.
2209 auto state() const {
2210 return std::make_pair(axis_, shape_);
2211 }
2212
2213 private:
2214 int axis_;
2215 Shape shape_;
// Presumably the shared implementation behind eval_cpu and eval_gpu -- confirm
// in the source file.
2216 void eval(const std::vector<array>& inputs, array& out);
2217};
2218
// View: primitive derived from UnaryPrimitive that stores a Dtype from the
// constructor.
// NOTE(review): listing line 2227 is absent here.
2219class View : public UnaryPrimitive {
2220 public:
2221 explicit View(Stream stream, Dtype dtype)
2222 : UnaryPrimitive(stream), dtype_(dtype) {}
2223
2224 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2225 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2226
// print() is only declared here (no inline body); its output is defined
// outside this header listing.
2228 void print(std::ostream& os) override;
2229 bool is_equivalent(const Primitive& other) const override;
// Returns the stored dtype.
2230 auto state() const {
2231 return dtype_;
2232 }
2233
2234 private:
2235 Dtype dtype_;
2236};
2237
// Transpose (name taken from the constructor; the class declaration line 2238
// is absent from this listing, as are lines 2246-2248).
// Stores a list of axes copied from the constructor argument.
2239 public:
2240 explicit Transpose(Stream stream, const std::vector<int>& axes)
2241 : UnaryPrimitive(stream), axes_(axes) {}
2242
2243 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2244 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2245
2249 bool is_equivalent(const Primitive& other) const override;
2250 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
// Returns a copy of the stored axes. The ';' after the closing brace is a
// redundant empty declaration.
2251 std::vector<int> state() const {
2252 return axes_;
2253 };
2254
2255 private:
2256 std::vector<int> axes_;
2257
// Presumably the shared implementation behind eval_cpu and eval_gpu -- confirm
// in the source file.
2258 void eval(const std::vector<array>& inputs, array& out);
2259};
2260
2261/* QR Factorization primitive. */
// Derives directly from Primitive, so its eval overrides take a vector of
// output arrays rather than a single `out`.
// NOTE(review): listing lines 2264 (presumably the constructor) and 2271 are
// absent here.
2262class QRF : public Primitive {
2263 public:
2265
2266 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2267 override;
2268 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2269 override;
2270
2272};
2273
2274/* SVD primitive. */
// Derives directly from Primitive, so its eval overrides take a vector of
// output arrays rather than a single `out`.
// NOTE(review): listing lines 2277 (presumably the constructor) and 2284-2285
// are absent here.
2275class SVD : public Primitive {
2276 public:
2278
2279 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2280 override;
2281 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2282 override;
2283
2286};
2287
2288/* Matrix inversion primitive. */
// Stores two flags from the constructor, `tri` and `upper`; presumably they
// select a triangular inverse and which triangle is used -- confirm in the
// implementation.
// NOTE(review): listing lines 2297-2298 are absent here, and no is_equivalent
// override is visible in the lines that remain.
2289class Inverse : public UnaryPrimitive {
2290 public:
2291 explicit Inverse(Stream stream, bool tri, bool upper)
2292 : UnaryPrimitive(stream), tri_(tri), upper_(upper) {}
2293
2294 void eval_cpu(const std::vector<array>& inputs, array& output) override;
2295 void eval_gpu(const std::vector<array>& inputs, array& output) override;
2296
// Returns the pair (tri, upper).
2299 auto state() const {
2300 return std::make_pair(tri_, upper_);
2301 }
2302
2303 private:
2304 bool tri_;
2305 bool upper_;
2306};
2307
// Cholesky: primitive derived from UnaryPrimitive that stores one boolean,
// `upper`, from the constructor; presumably it selects the upper- versus
// lower-triangular factor -- confirm in the implementation.
// NOTE(review): listing lines 2319-2320 are absent here.
2308class Cholesky : public UnaryPrimitive {
2309 public:
2310 explicit Cholesky(Stream stream, bool upper)
2311 : UnaryPrimitive(stream), upper_(upper) {}
2312
2313 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2314 void eval_gpu(const std::vector<array>& inputs, array& out) override;
// Returns the stored `upper` flag.
2315 auto state() const {
2316 return upper_;
2317 }
2318
2321
2322 private:
2323 bool upper_;
2324};
2325
// Eigh: derives directly from Primitive, so its eval overrides take a vector
// of output arrays. Stores a string `uplo` (taken by value and moved into the
// member) and a `compute_eigenvectors` flag.
// NOTE(review): listing lines 2338-2339 are absent here; the accepted values
// of `uplo` are not visible in this header listing.
2326class Eigh : public Primitive {
2327 public:
2328 explicit Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
2329 : Primitive(stream),
2330 uplo_(std::move(uplo)),
2331 compute_eigenvectors_(compute_eigenvectors) {}
2332
2333 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2334 override;
2335 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2336 override;
2337
2340
2341 std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
2342
2343 bool is_equivalent(const Primitive& other) const override;
// Returns the pair (uplo, compute_eigenvectors); the string is copied.
2344 auto state() const {
2345 return std::make_pair(uplo_, compute_eigenvectors_);
2346 }
2347
2348 private:
2349 std::string uplo_;
2350 bool compute_eigenvectors_;
2351};
2352
2353} // namespace mlx::core
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Abs(Stream stream)
Definition primitives.h:156
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Add(Stream stream)
Definition primitives.h:170
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::pair< float, float > state() const
Definition primitives.h:195
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
AddMM(Stream stream, float alpha, float beta)
Definition primitives.h:184
Arange(Stream stream, double start, double stop, double step)
Definition primitives.h:206
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::tuple< double, double, double > state() const
Definition primitives.h:215
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCos(Stream stream)
Definition primitives.h:227
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCosh(Stream stream)
Definition primitives.h:241
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcSin(Stream stream)
Definition primitives.h:255
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcSinh(Stream stream)
Definition primitives.h:269
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTan2(Stream stream)
Definition primitives.h:297
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcTan(Stream stream)
Definition primitives.h:283
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTanh(Stream stream)
Definition primitives.h:311
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::pair< int, int > state() const
Definition primitives.h:336
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArgPartition(Stream stream, int kth, int axis)
Definition primitives.h:325
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
ReduceType
Definition primitives.h:347
@ ArgMin
Definition primitives.h:348
@ ArgMax
Definition primitives.h:349
ArgReduce(Stream stream, ReduceType reduce_type, int axis)
Definition primitives.h:352
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::pair< ReduceType, int > state() const
Definition primitives.h:363
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
ArgSort(Stream stream, int axis)
Definition primitives.h:374
int state() const
Definition primitives.h:384
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:427
AsStrided(Stream stream, Shape shape, Strides strides, size_t offset)
Definition primitives.h:415
void eval_gpu(const std::vector< array > &inputs, array &out) override
AsType(Stream stream, Dtype dtype)
Definition primitives.h:394
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Dtype state() const
Definition primitives.h:405
void eval_cpu(const std::vector< array > &inputs, array &out) override
BitwiseBinary(Stream stream, Op op)
Definition primitives.h:443
void eval_cpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Op
Definition primitives.h:441
@ RightShift
Definition primitives.h:441
@ Or
Definition primitives.h:441
@ LeftShift
Definition primitives.h:441
@ And
Definition primitives.h:441
@ Xor
Definition primitives.h:441
auto state() const
Definition primitives.h:454
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
auto state() const
Definition primitives.h:478
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
BlockMaskedMM(Stream stream, int block_size)
Definition primitives.h:464
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
BroadcastAxes(Stream stream, std::vector< int > ignore_axes={})
Definition primitives.h:505
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:519
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const std::vector< array > &inputs, const std::vector< int > &ignore_axes)
Broadcast(Stream stream, const Shape &shape)
Definition primitives.h:530
static Shape output_shape(const std::vector< array > &inputs)
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< int > state() const
Definition primitives.h:542
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Ceil(Stream stream)
Definition primitives.h:554
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2315
Cholesky(Stream stream, bool upper)
Definition primitives.h:2310
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void print(std::ostream &os) override
Print the primitive.
Compiled(Stream stream, std::vector< array > inputs, std::vector< array > outputs, std::vector< array > tape, std::unordered_set< uintptr_t > constant_ids)
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
std::string lib_name() const
Definition primitives.h:595
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:621
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Concatenate(Stream stream, int axis)
Definition primitives.h:610
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Conjugate(Stream stream)
Definition primitives.h:631
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Contiguous(Stream stream, bool allow_col_major)
Definition primitives.h:644
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
Convolution(Stream stream, const std::vector< int > &kernel_strides, const std::vector< int > &padding, const std::vector< int > &kernel_dilation, const std::vector< int > &input_dilation, const int groups=1, const bool flip=false)
Definition primitives.h:663
auto state() const
Definition primitives.h:690
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
Copy(Stream stream)
Definition primitives.h:711
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cos(Stream stream)
Definition primitives.h:728
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Cosh(Stream stream)
Definition primitives.h:742
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes) override
The primitive must know how to vectorize itself across the given axes.
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
The Jacobian-vector product.
CustomTransforms(Stream stream, int num_outputs, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> vjp, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< int > &)> jvp, std::function< std::pair< std::vector< array >, std::vector< int > >(const std::vector< array > &, const std::vector< int > &)> vmap)
Definition primitives.h:756
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Depends(Stream stream)
Definition primitives.h:808
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:854
DivMod(Stream stream)
Definition primitives.h:843
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Divide(Stream stream)
Definition primitives.h:829
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
DynamicSlice(Stream stream, std::vector< int > axes, Shape slice_size)
Definition primitives.h:1958
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1971
auto state() const
Definition primitives.h:1993
DynamicSliceUpdate(Stream stream, std::vector< int > axes)
Definition primitives.h:1982
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
auto state() const
Definition primitives.h:2344
Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
Definition primitives.h:2328
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:900
Equal(Stream stream, bool equal_nan=false)
Definition primitives.h:889
auto state() const
Definition primitives.h:907
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Erf(Stream stream)
Definition primitives.h:917
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ErfInv(Stream stream)
Definition primitives.h:931
void eval_cpu(const std::vector< array > &inputs, array &out) override
Exp(Stream stream)
Definition primitives.h:945
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const array &input, const std::vector< int > &axes)
auto state() const
Definition primitives.h:986
void eval_gpu(const std::vector< array > &inputs, array &out) override
ExpandDims(Stream stream, std::vector< int > axes)
Definition primitives.h:972
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Expm1(Stream stream)
Definition primitives.h:959
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
FFT(Stream stream, const std::vector< size_t > &axes, bool inverse, bool real)
Definition primitives.h:997
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1012
static Shape output_shape(const array &input, int start_axis, int end_axis)
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
Flatten(Stream stream, int start_axis, int end_axis)
Definition primitives.h:1024
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1037
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Floor(Stream stream)
Definition primitives.h:1049
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Full(Stream stream)
Definition primitives.h:1063
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
GatherAxis(Stream stream, int axis)
Definition primitives.h:1100
auto state() const
Definition primitives.h:1111
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::pair< std::vector< int >, std::vector< int > > state() const
Definition primitives.h:1089
Gather(Stream stream, std::vector< int > axes, Shape slice_sizes)
Definition primitives.h:1076
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
GatherMM(Stream stream)
Definition primitives.h:488
auto state() const
Definition primitives.h:1575
GatherQMM(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1562
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
GreaterEqual(Stream stream)
Definition primitives.h:1135
void eval_gpu(const std::vector< array > &inputs, array &out) override
Greater(Stream stream)
Definition primitives.h:1121
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
Hadamard(Stream stream, float scale)
Definition primitives.h:1149
auto state() const
Definition primitives.h:1161
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Imag(Stream stream)
Definition primitives.h:1171
void eval_gpu(const std::vector< array > &inputs, array &output) override
Inverse(Stream stream, bool tri, bool upper)
Definition primitives.h:2291
auto state() const
Definition primitives.h:2299
void eval_cpu(const std::vector< array > &inputs, array &output) override
LessEqual(Stream stream)
Definition primitives.h:1199
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Less(Stream stream)
Definition primitives.h:1185
void eval_gpu(const std::vector< array > &inputs, array &out) override
Load(Stream stream, std::shared_ptr< io::Reader > reader, size_t offset, bool swap_endianness=false)
Definition primitives.h:1213
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Log1p(Stream stream)
Definition primitives.h:1281
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogAddExp(Stream stream)
Definition primitives.h:1336
Base
Definition primitives.h:1244
@ ten
Definition primitives.h:1244
@ two
Definition primitives.h:1244
@ e
Definition primitives.h:1244
Log(Stream stream, Base base)
Definition primitives.h:1246
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1261
Base state() const
Definition primitives.h:1257
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalAnd(Stream stream)
Definition primitives.h:1308
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalNot(Stream stream)
Definition primitives.h:1294
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
LogicalOr(Stream stream)
Definition primitives.h:1322
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Matmul(Stream stream)
Definition primitives.h:1350
Maximum(Stream stream)
Definition primitives.h:1364
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Minimum(Stream stream)
Definition primitives.h:1378
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Multiply(Stream stream)
Definition primitives.h:1392
void eval_gpu(const std::vector< array > &inputs, array &out) override
Negative(Stream stream)
Definition primitives.h:1406
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
NotEqual(Stream stream)
Definition primitives.h:1420
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:1450
NumberOfElements(Stream stream, std::vector< int > axes, bool inverted, Dtype dtype)
Definition primitives.h:1434
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::tuple< std::vector< int >, bool, Dtype > state() const
Definition primitives.h:1453
auto state() const
Definition primitives.h:1484
Pad(Stream stream, const std::vector< int > &axes, const Shape &low_pad_size, const Shape &high_pad_size)
Definition primitives.h:1467
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Partition(Stream stream, int kth, int axis)
Definition primitives.h:1496
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
auto state() const
Definition primitives.h:1507
void eval_cpu(const std::vector< array > &inputs, array &out) override
Power(Stream stream)
Definition primitives.h:1518
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:48
virtual void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
The vector-Jacobian product.
virtual ~Primitive()=default
Primitive(const Primitive &other)=delete
Primitive(Primitive &&other)=delete
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
Primitive & operator=(Primitive &&other)=delete
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
Primitive & operator=(const Primitive &other)=delete
const Device & device()
The device the primitive will run on.
Definition primitives.h:53
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
The Jacobian-vector product.
virtual std::vector< Shape > output_shapes(const std::vector< array > &inputs)
Get the output shapes of the primitive.
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes)
The primitive must know how to vectorize itself across the given axes.
virtual void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
virtual void print(std::ostream &os)=0
Print the primitive.
Primitive(Stream stream)
Definition primitives.h:50
QRF(Stream stream)
Definition primitives.h:2264
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
QuantizedMatmul(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1532
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1550
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::pair< std::vector< int >, int > state() const
Definition primitives.h:1596
RandomBits(Stream stream, const Shape &shape, int width)
Definition primitives.h:1587
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Real(Stream stream)
Definition primitives.h:1607
Reduce(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1645
ReduceType
Definition primitives.h:1643
@ Min
Definition primitives.h:1643
@ Or
Definition primitives.h:1643
@ Max
Definition primitives.h:1643
@ And
Definition primitives.h:1643
@ Sum
Definition primitives.h:1643
@ Prod
Definition primitives.h:1643
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1664
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::pair< ReduceType, std::vector< int > > state() const
Definition primitives.h:1687
Remainder(Stream stream)
Definition primitives.h:875
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const array &input, Shape shape)
void eval_gpu(const std::vector< array > &inputs, array &out) override
Reshape(Stream stream, const Shape &shape)
Definition primitives.h:1621
std::vector< int > state() const
Definition primitives.h:1631
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Round(Stream stream)
Definition primitives.h:1698
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
SVD(Stream stream)
Definition primitives.h:2277
void eval_cpu(const std::vector< array > &inputs, array &out) override
ReduceType
Definition primitives.h:1712
@ Prod
Definition primitives.h:1712
@ Min
Definition primitives.h:1712
@ Max
Definition primitives.h:1712
@ Sum
Definition primitives.h:1712
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
auto state() const
Definition primitives.h:1750
Scan(Stream stream, ReduceType reduce_type, int axis, bool reverse, bool inclusive)
Definition primitives.h:1714
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1732
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::pair< ReduceType, int > state() const
Definition primitives.h:1832
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1819
void eval_gpu(const std::vector< array > &inputs, array &out) override
ScatterAxis(Stream stream, ReduceType reduce_type, int axis)
Definition primitives.h:1810
ReduceType
Definition primitives.h:1808
@ Sum
Definition primitives.h:1808
@ None
Definition primitives.h:1808
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::pair< ReduceType, std::vector< int > > state() const
Definition primitives.h:1797
ReduceType
Definition primitives.h:1763
@ Sum
Definition primitives.h:1763
@ Max
Definition primitives.h:1763
@ Prod
Definition primitives.h:1763
@ None
Definition primitives.h:1763
@ Min
Definition primitives.h:1763
void eval_cpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1777
void eval_gpu(const std::vector< array > &inputs, array &out) override
Scatter(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1765
void eval_gpu(const std::vector< array > &inputs, array &out) override
Select(Stream stream)
Definition primitives.h:861
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sigmoid(Stream stream)
Definition primitives.h:1843
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Sign(Stream stream)
Definition primitives.h:1857
Sin(Stream stream)
Definition primitives.h:1871
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sinh(Stream stream)
Definition primitives.h:1885
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1916
Slice(Stream stream, const Shape &start_indices, const Shape &end_indices, const Shape &strides)
Definition primitives.h:1899
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
SliceUpdate(Stream stream, const Shape &start_indices, const Shape &end_indices, const Shape &strides)
Definition primitives.h:1928
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:1946
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Softmax(Stream stream, bool precise)
Definition primitives.h:2003
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2015
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2036
Sort(Stream stream, int axis)
Definition primitives.h:2025
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::pair< std::vector< int >, int > state() const
Definition primitives.h:2058
Split(Stream stream, const Shape &indices, int axis)
Definition primitives.h:2046
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
auto state() const
Definition primitives.h:2095
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sqrt(Stream stream, bool recip=false)
Definition primitives.h:2085
void eval_gpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:2099
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Square(Stream stream)
Definition primitives.h:2071
Squeeze(Stream stream, std::vector< int > axes)
Definition primitives.h:2143
auto state() const
Definition primitives.h:2157
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
void eval_cpu(const std::vector< array > &inputs, array &out) override
static Shape output_shape(const array &input, const std::vector< int > &axes)
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
StopGradient(Stream stream)
Definition primitives.h:2113
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Subtract(Stream stream)
Definition primitives.h:2129
Tan(Stream stream)
Definition primitives.h:2168
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Tanh(Stream stream)
Definition primitives.h:2182
void eval_cpu(const std::vector< array > &inputs, array &out) override
Transpose(Stream stream, const std::vector< int > &axes)
Definition primitives.h:2240
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< int > state() const
Definition primitives.h:2251
void eval_gpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:126
UnaryPrimitive & operator=(const UnaryPrimitive &other)=delete
UnaryPrimitive(Stream stream)
An abstract base class for a primitive with a single output.
Definition primitives.h:131
virtual void eval_gpu(const std::vector< array > &inputs, array &output)=0
UnaryPrimitive(UnaryPrimitive &&other)=delete
virtual void eval_cpu(const std::vector< array > &inputs, array &output)=0
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:141
UnaryPrimitive(const UnaryPrimitive &other)=delete
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition primitives.h:136
UnaryPrimitive & operator=(UnaryPrimitive &&other)=delete
virtual ~UnaryPrimitive()=default
std::vector< Shape > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Unflatten(Stream stream, int axis, Shape shape)
Definition primitives.h:2196
static Shape output_shape(const array &input, int axis, const Shape &shape)
void eval_cpu(const std::vector< array > &inputs, array &out) override
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2209
void eval_cpu(const std::vector< array > &inputs, array &out) override
auto state() const
Definition primitives.h:2230
void print(std::ostream &os) override
Print the primitive.
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
View(Stream stream, Dtype dtype)
Definition primitives.h:2221
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition array.h:24
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
array tri(int n, int m, int k, Dtype type, StreamOrDevice s={})
array transpose(const array &a, std::vector< int > axes, StreamOrDevice s={})
Permutes the dimensions according to the given axes.
array real(const array &a, StreamOrDevice s={})
Definition allocator.h:7
std::vector< ShapeElem > Shape
Definition array.h:21
Stream new_stream(Device d)
Make a new stream on the given device.
std::vector< int64_t > Strides
Definition array.h:22
void eval(std::vector< array > outputs)
#define DEFINE_DEFAULT_IS_EQUIVALENT()
Definition primitives.h:34
#define DEFINE_PRINT(PRIMITIVE)
Definition primitives.h:29
#define DEFINE_INPUT_OUTPUT_SHAPE()
Definition primitives.h:39
#define DEFINE_GRADS()
Definition primitives.h:17
#define DEFINE_VMAP()
Definition primitives.h:12
Definition device.h:7
static constexpr DeviceType gpu
Definition device.h:14
static constexpr DeviceType cpu
Definition device.h:13
Definition dtype.h:13
Definition stream.h:9
Device device
Definition stream.h:11