MLX
Loading...
Searching...
No Matches
primitives.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <unordered_set>
6
7#include "mlx/array.h"
8#include "mlx/device.h"
9#include "mlx/io/load.h"
10#include "mlx/stream.h"
11
12#define DEFINE_VMAP() \
13 virtual std::pair<std::vector<array>, std::vector<int>> vmap( \
14 const std::vector<array>& inputs, const std::vector<int>& axes) \
15 override;
16
17#define DEFINE_GRADS() \
18 std::vector<array> jvp( \
19 const std::vector<array>& primals, \
20 const std::vector<array>& tangents, \
21 const std::vector<int>& argnums) override; \
22 \
23 std::vector<array> vjp( \
24 const std::vector<array>& primals, \
25 const std::vector<array>& cotangents, \
26 const std::vector<int>& argnums, \
27 const std::vector<array>& outputs) override;
28
29#define DEFINE_PRINT(PRIMITIVE) \
30 void print(std::ostream& os) override { \
31 os << #PRIMITIVE; \
32 }
33
34#define DEFINE_DEFAULT_IS_EQUIVALENT() \
35 bool is_equivalent(const Primitive& other) const override { \
36 return true; \
37 }
38
39#define DEFINE_INPUT_OUTPUT_SHAPE() \
40 std::vector<std::vector<int>> output_shapes( \
41 const std::vector<array>& inputs) override { \
42 return {inputs[0].shape()}; \
43 }
44
45namespace mlx::core {
46
47// Abstract base class
48class Primitive {
49 public:
50 explicit Primitive(Stream stream) : stream_(stream) {}
51
53 const Device& device() {
54 return stream().device;
55 }
56
58 const Stream& stream() {
59 return stream_;
60 }
61
69 virtual void eval_cpu(
70 const std::vector<array>& inputs,
71 std::vector<array>& outputs) = 0;
72 virtual void eval_gpu(
73 const std::vector<array>& inputs,
74 std::vector<array>& outputs) = 0;
75
79 virtual std::vector<array> jvp(
80 const std::vector<array>& primals,
81 const std::vector<array>& tangents,
82 const std::vector<int>& argnums);
83
87 virtual std::vector<array> vjp(
88 const std::vector<array>& primals,
89 const std::vector<array>& cotangents,
90 const std::vector<int>& argnums,
91 const std::vector<array>& outputs);
92
99 virtual std::pair<std::vector<array>, std::vector<int>> vmap(
100 const std::vector<array>& inputs,
101 const std::vector<int>& axes);
102
104 virtual void print(std::ostream& os) = 0;
105
107 virtual bool is_equivalent(const Primitive& other) const {
108 return false;
109 }
110
113 virtual std::vector<std::vector<int>> output_shapes(
114 const std::vector<array>& inputs);
115
116 virtual ~Primitive() = default;
117 Primitive(const Primitive& other) = delete;
118 Primitive(Primitive&& other) = delete;
119 Primitive& operator=(const Primitive& other) = delete;
120 Primitive& operator=(Primitive&& other) = delete;
121
122 private:
123 // Every primitive stores the stream it should run in
124 Stream stream_;
125};
126
127class UnaryPrimitive : public Primitive {
131 public:
133
134 virtual void eval_cpu(const std::vector<array>& inputs, array& output) = 0;
135 virtual void eval_gpu(const std::vector<array>& inputs, array& output) = 0;
136
137 inline void eval_cpu(
138 const std::vector<array>& inputs,
139 std::vector<array>& outputs) override {
140 eval_cpu(inputs, outputs[0]);
141 }
142 inline void eval_gpu(
143 const std::vector<array>& inputs,
144 std::vector<array>& outputs) override {
145 eval_gpu(inputs, outputs[0]);
146 }
147
148 virtual ~UnaryPrimitive() = default;
149 UnaryPrimitive(const UnaryPrimitive& other) = delete;
151 UnaryPrimitive& operator=(const UnaryPrimitive& other) = delete;
153};
154
155class Abs : public UnaryPrimitive {
156 public:
158
159 void eval_cpu(const std::vector<array>& inputs, array& out) override;
160 void eval_gpu(const std::vector<array>& inputs, array& out) override;
161
167
168 private:
169 void eval(const std::vector<array>& inputs, array& out);
170};
171
172class Add : public UnaryPrimitive {
173 public:
175
176 void eval_cpu(const std::vector<array>& inputs, array& out) override;
177 void eval_gpu(const std::vector<array>& inputs, array& out) override;
178
184
185 private:
186 void eval(const std::vector<array>& inputs, array& out);
187};
188
189class AddMM : public UnaryPrimitive {
190 public:
191 explicit AddMM(Stream stream, float alpha, float beta)
192 : UnaryPrimitive(stream), alpha_(alpha), beta_(beta) {}
193
194 void eval_cpu(const std::vector<array>& inputs, array& out) override;
195 void eval_gpu(const std::vector<array>& inputs, array& out) override;
196
197 std::vector<array> vjp(
198 const std::vector<array>& primals,
199 const std::vector<array>& cotangents,
200 const std::vector<int>& argnums,
201 const std::vector<array>& outputs) override;
202
205
206 bool is_equivalent(const Primitive& other) const override;
207
208 private:
209 const float alpha_;
210 const float beta_;
211};
212
213class Arange : public UnaryPrimitive {
214 public:
215 explicit Arange(Stream stream, double start, double stop, double step)
216 : UnaryPrimitive(stream), start_(start), stop_(stop), step_(step) {}
217
218 void eval_cpu(const std::vector<array>& inputs, array& out) override;
219 void eval_gpu(const std::vector<array>& inputs, array& out) override;
220
222 bool is_equivalent(const Primitive& other) const override;
223
224 private:
225 double start_;
226 double stop_;
227 double step_;
228
229 void eval(const std::vector<array>& inputs, array& out);
230};
231
232class ArcCos : public UnaryPrimitive {
233 public:
235
236 void eval_cpu(const std::vector<array>& inputs, array& out) override;
237 void eval_gpu(const std::vector<array>& inputs, array& out) override;
238
244
245 private:
246 void eval(const std::vector<array>& inputs, array& out);
247};
248
249class ArcCosh : public UnaryPrimitive {
250 public:
252
253 void eval_cpu(const std::vector<array>& inputs, array& out) override;
254 void eval_gpu(const std::vector<array>& inputs, array& out) override;
255
261
262 private:
263 void eval(const std::vector<array>& inputs, array& out);
264};
265
266class ArcSin : public UnaryPrimitive {
267 public:
269
270 void eval_cpu(const std::vector<array>& inputs, array& out) override;
271 void eval_gpu(const std::vector<array>& inputs, array& out) override;
272
278
279 private:
280 void eval(const std::vector<array>& inputs, array& out);
281};
282
283class ArcSinh : public UnaryPrimitive {
284 public:
286
287 void eval_cpu(const std::vector<array>& inputs, array& out) override;
288 void eval_gpu(const std::vector<array>& inputs, array& out) override;
289
295
296 private:
297 void eval(const std::vector<array>& inputs, array& out);
298};
299
300class ArcTan : public UnaryPrimitive {
301 public:
303
304 void eval_cpu(const std::vector<array>& inputs, array& out) override;
305 void eval_gpu(const std::vector<array>& inputs, array& out) override;
306
312
313 private:
314 void eval(const std::vector<array>& inputs, array& out);
315};
316
317class ArcTan2 : public UnaryPrimitive {
318 public:
320
321 void eval_cpu(const std::vector<array>& inputs, array& out) override;
322 void eval_gpu(const std::vector<array>& inputs, array& out) override;
323
329
330 private:
331 void eval(const std::vector<array>& inputs, array& out);
332};
333
334class ArcTanh : public UnaryPrimitive {
335 public:
337
338 void eval_cpu(const std::vector<array>& inputs, array& out) override;
339 void eval_gpu(const std::vector<array>& inputs, array& out) override;
340
346
347 private:
348 void eval(const std::vector<array>& inputs, array& out);
349};
350
352 public:
353 explicit ArgPartition(Stream stream, int kth, int axis)
354 : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
355
356 void eval_cpu(const std::vector<array>& inputs, array& out) override;
357 void eval_gpu(const std::vector<array>& inputs, array& out) override;
358
363 bool is_equivalent(const Primitive& other) const override;
364
365 private:
366 int kth_;
367 int axis_;
368
369 void eval(const std::vector<array>& inputs, array& out);
370};
371
372class ArgReduce : public UnaryPrimitive {
373 public:
378
379 explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)
380 : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {}
381
382 void eval_cpu(const std::vector<array>& inputs, array& out) override;
383 void eval_gpu(const std::vector<array>& inputs, array& out) override;
384
388 bool is_equivalent(const Primitive& other) const override;
389 std::vector<std::vector<int>> output_shapes(
390 const std::vector<array>& inputs) override;
391
392 private:
393 ReduceType reduce_type_;
394 int axis_;
395
396 void eval(const std::vector<array>& inputs, array& out);
397};
398
399class ArgSort : public UnaryPrimitive {
400 public:
401 explicit ArgSort(Stream stream, int axis)
402 : UnaryPrimitive(stream), axis_(axis) {}
403
404 void eval_cpu(const std::vector<array>& inputs, array& out) override;
405 void eval_gpu(const std::vector<array>& inputs, array& out) override;
406
410 bool is_equivalent(const Primitive& other) const override;
411
412 private:
413 int axis_;
414
415 void eval(const std::vector<array>& inputs, array& out);
416};
417
418class AsType : public UnaryPrimitive {
419 public:
420 explicit AsType(Stream stream, Dtype dtype)
421 : UnaryPrimitive(stream), dtype_(dtype) {}
422
423 void eval_cpu(const std::vector<array>& inputs, array& out) override;
424 void eval_gpu(const std::vector<array>& inputs, array& out) override;
425
430 bool is_equivalent(const Primitive& other) const override;
431
432 private:
433 Dtype dtype_;
434
435 void eval(const std::vector<array>& inputs, array& out);
436};
437
438class AsStrided : public UnaryPrimitive {
439 public:
440 explicit AsStrided(
442 std::vector<int> shape,
443 std::vector<size_t> strides,
444 size_t offset)
446 shape_(std::move(shape)),
447 strides_(std::move(strides)),
448 offset_(offset) {}
449
450 void eval_cpu(const std::vector<array>& inputs, array& out) override;
451 void eval_gpu(const std::vector<array>& inputs, array& out) override;
452
455 bool is_equivalent(const Primitive& other) const override;
456
457 private:
458 std::vector<int> shape_;
459 std::vector<size_t> strides_;
460 size_t offset_;
461
462 void eval(const std::vector<array>& inputs, array& out);
463};
464
466 public:
467 enum Op { And, Or, Xor, LeftShift, RightShift };
468
471
472 void eval_cpu(const std::vector<array>& inputs, array& out) override;
473 void eval_gpu(const std::vector<array>& inputs, array& out) override;
474
477 bool is_equivalent(const Primitive& other) const override;
478 void print(std::ostream& os) override;
480
481 private:
482 Op op_;
483};
484
486 public:
487 explicit BlockMaskedMM(Stream stream, int block_size)
488 : UnaryPrimitive(stream), block_size_(block_size) {}
489
490 void eval_cpu(const std::vector<array>& inputs, array& out) override;
491 void eval_gpu(const std::vector<array>& inputs, array& out) override;
492
493 std::vector<array> vjp(
494 const std::vector<array>& primals,
495 const std::vector<array>& cotangents,
496 const std::vector<int>& argnums,
497 const std::vector<array>& outputs) override;
498
500 bool is_equivalent(const Primitive& other) const override;
501
502 private:
503 int block_size_;
504
505 void eval(const std::vector<array>& inputs, array& out);
506};
507
508class GatherMM : public UnaryPrimitive {
509 public:
511
512 void eval_cpu(const std::vector<array>& inputs, array& out) override;
513 void eval_gpu(const std::vector<array>& inputs, array& out) override;
514
515 std::vector<array> vjp(
516 const std::vector<array>& primals,
517 const std::vector<array>& cotangents,
518 const std::vector<int>& argnums,
519 const std::vector<array>& outputs) override;
520
523
524 private:
525 void eval(const std::vector<array>& inputs, array& out);
526};
527
528class Broadcast : public UnaryPrimitive {
529 public:
530 explicit Broadcast(Stream stream, const std::vector<int>& shape)
531 : UnaryPrimitive(stream), shape_(shape) {}
532
533 void eval_cpu(const std::vector<array>& inputs, array& out) override;
534 void eval_gpu(const std::vector<array>& inputs, array& out) override;
535
539 bool is_equivalent(const Primitive& other) const override;
540
541 private:
542 std::vector<int> shape_;
543
544 void eval(const std::vector<array>& inputs, array& out);
545};
546
547class Ceil : public UnaryPrimitive {
548 public:
550
551 void eval_cpu(const std::vector<array>& inputs, array& out) override;
552 void eval_gpu(const std::vector<array>& inputs, array& out) override;
553
559
560 private:
561 void eval(const std::vector<array>& inputs, array& out);
562};
563
564class Compiled : public Primitive {
565 public:
566 /*
567 * The inputs, outputs and tape are either tracers or constants.
568 * - The tape should not contain the inputs, but it should contain the
569 * outputs.
570 * - The tape should also have only one array per primitive for multi-output
571 * primitives.
572 * - The constant_ids contains ids of arrays in the input list that are safe
573 * to treat as scalar constants.
574 */
575 explicit Compiled(
577 std::vector<array> inputs,
578 std::vector<array> outputs,
579 std::vector<array> tape,
580 std::unordered_set<uintptr_t> constant_ids);
581
582 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
583 override;
584 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
585 override;
586
589 std::vector<std::vector<int>> output_shapes(
590 const std::vector<array>& inputs) override;
591 void print(std::ostream& os) override;
592 bool is_equivalent(const Primitive& other) const override;
593
594 std::string lib_name() const {
595 return kernel_lib_;
596 }
597
598 private:
599 const std::vector<array> inputs_;
600 const std::vector<array> outputs_;
601 const std::vector<array> tape_;
602 const std::unordered_set<uintptr_t> constant_ids_;
603
604 std::string kernel_lib_;
605};
606
608 public:
609 explicit Concatenate(Stream stream, int axis)
610 : UnaryPrimitive(stream), axis_(axis) {}
611
612 void eval_cpu(const std::vector<array>& inputs, array& out) override;
613 void eval_gpu(const std::vector<array>& inputs, array& out) override;
614
618 bool is_equivalent(const Primitive& other) const override;
619
620 private:
621 int axis_;
622
623 void eval(const std::vector<array>& inputs, array& out);
624};
625
626class Conjugate : public UnaryPrimitive {
627 public:
628 explicit Conjugate(Stream stream) : UnaryPrimitive(stream) {}
629
630 void eval_cpu(const std::vector<array>& inputs, array& out) override;
631 void eval_gpu(const std::vector<array>& inputs, array& out) override;
632
637
638 private:
639 void eval(const std::vector<array>& inputs, array& out);
640};
641
643 public:
644 explicit Contiguous(Stream stream, bool allow_col_major)
645 : UnaryPrimitive(stream), allow_col_major_(allow_col_major) {}
646
647 void eval_cpu(const std::vector<array>& inputs, array& out) override;
648 void eval_gpu(const std::vector<array>& inputs, array& out) override;
649
654
655 bool is_equivalent(const Primitive& other) const override;
656
657 private:
658 bool allow_col_major_;
659};
660
662 public:
663 explicit Convolution(
664 Stream stream,
665 const std::vector<int>& kernel_strides,
666 const std::vector<int>& padding,
667 const std::vector<int>& kernel_dilation,
668 const std::vector<int>& input_dilation,
669 const int groups = 1,
670 const bool flip = false)
671 : UnaryPrimitive(stream),
672 padding_(padding),
673 kernel_strides_(kernel_strides),
674 kernel_dilation_(kernel_dilation),
675 input_dilation_(input_dilation),
676 groups_(groups),
677 flip_(flip) {}
678
679 void eval_cpu(const std::vector<array>& inputs, array& out) override;
680 void eval_gpu(const std::vector<array>& inputs, array& out) override;
681
682 std::vector<array> vjp(
683 const std::vector<array>& primals,
684 const std::vector<array>& cotangents,
685 const std::vector<int>& argnums,
686 const std::vector<array>& outputs) override;
687
689 bool is_equivalent(const Primitive& other) const override;
690
691 private:
692 std::vector<int> padding_;
693 std::vector<int> kernel_strides_;
694 std::vector<int> kernel_dilation_;
695 std::vector<int> input_dilation_;
696 int groups_;
697 bool flip_;
698
699 void eval(const std::vector<array>& inputs, array& out);
700};
701
702class Copy : public UnaryPrimitive {
703 public:
704 explicit Copy(Stream stream) : UnaryPrimitive(stream) {}
705
706 void eval_cpu(const std::vector<array>& inputs, array& out) override;
707 void eval_gpu(const std::vector<array>& inputs, array& out) override;
708
714
715 private:
716 void eval(const std::vector<array>& inputs, array& out);
717};
718
719class Cos : public UnaryPrimitive {
720 public:
721 explicit Cos(Stream stream) : UnaryPrimitive(stream) {}
722
723 void eval_cpu(const std::vector<array>& inputs, array& out) override;
724 void eval_gpu(const std::vector<array>& inputs, array& out) override;
725
731
732 private:
733 void eval(const std::vector<array>& inputs, array& out);
734};
735
736class Cosh : public UnaryPrimitive {
737 public:
738 explicit Cosh(Stream stream) : UnaryPrimitive(stream) {}
739
740 void eval_cpu(const std::vector<array>& inputs, array& out) override;
741 void eval_gpu(const std::vector<array>& inputs, array& out) override;
742
748
749 private:
750 void eval(const std::vector<array>& inputs, array& out);
751};
752
754 public:
756 Stream stream,
757 int num_outputs,
758 std::function<std::vector<array>(
759 const std::vector<array>&,
760 const std::vector<array>&,
761 const std::vector<array>&)> vjp,
762 std::function<std::vector<array>(
763 const std::vector<array>&,
764 const std::vector<array>&,
765 const std::vector<int>&)> jvp,
766 std::function<std::pair<std::vector<array>, std::vector<int>>(
767 const std::vector<array>&,
768 const std::vector<int>&)> vmap)
769 : Primitive(stream),
770 num_outputs_(num_outputs),
771 vjp_fun_(std::move(vjp)),
772 jvp_fun_(std::move(jvp)),
773 vmap_fun_(std::move(vmap)) {}
774
775 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
776 override;
777 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
778 override;
779
783
784 private:
785 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
786
787 int num_outputs_;
788
789 std::function<std::vector<array>(
790 const std::vector<array>&,
791 const std::vector<array>&,
792 const std::vector<array>&)>
793 vjp_fun_;
794 std::function<std::vector<array>(
795 const std::vector<array>&,
796 const std::vector<array>&,
797 const std::vector<int>&)>
798 jvp_fun_;
799 std::function<std::pair<std::vector<array>, std::vector<int>>(
800 const std::vector<array>&,
801 const std::vector<int>&)>
802 vmap_fun_;
803};
804
805class Depends : public Primitive {
806 public:
807 explicit Depends(Stream stream) : Primitive(stream) {}
808
809 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
810 override;
811 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
812 override;
813
814 std::vector<array> vjp(
815 const std::vector<array>& primals,
816 const std::vector<array>& cotan,
817 const std::vector<int>& argnums,
818 const std::vector<array>& outputs) override;
819
821
822 private:
823 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
824};
825
826class Divide : public UnaryPrimitive {
827 public:
828 explicit Divide(Stream stream) : UnaryPrimitive(stream) {}
829
830 void eval_cpu(const std::vector<array>& inputs, array& out) override;
831 void eval_gpu(const std::vector<array>& inputs, array& out) override;
832
838
839 private:
840 void eval(const std::vector<array>& inputs, array& out);
841};
842
843class DivMod : public Primitive {
844 public:
845 explicit DivMod(Stream stream) : Primitive(stream) {}
846
847 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
848 override;
849 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
850 override;
851
856 std::vector<std::vector<int>> output_shapes(
857 const std::vector<array>& inputs) override {
858 return std::vector{inputs[0].shape(), inputs[0].shape()};
859 }
860
861 private:
862 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
863};
864
865class Select : public UnaryPrimitive {
866 public:
867 explicit Select(Stream stream) : UnaryPrimitive(stream) {}
868
869 void eval_cpu(const std::vector<array>& inputs, array& out) override;
870 void eval_gpu(const std::vector<array>& inputs, array& out) override;
871
877
878 private:
879 void eval(const std::vector<array>& inputs, array& out);
880};
881
882class Remainder : public UnaryPrimitive {
883 public:
884 explicit Remainder(Stream stream) : UnaryPrimitive(stream) {}
885
886 void eval_cpu(const std::vector<array>& inputs, array& out) override;
887 void eval_gpu(const std::vector<array>& inputs, array& out) override;
888
894
895 private:
896 void eval(const std::vector<array>& inputs, array& out);
897};
898
899class Equal : public UnaryPrimitive {
900 public:
901 explicit Equal(Stream stream, bool equal_nan = false)
902 : UnaryPrimitive(stream), equal_nan_(equal_nan) {}
903
904 void eval_cpu(const std::vector<array>& inputs, array& out) override;
905 void eval_gpu(const std::vector<array>& inputs, array& out) override;
906
911
912 void print(std::ostream& os) override {
913 if (equal_nan_) {
914 os << "NaNEqual";
915 } else {
916 os << "Equal";
917 }
918 }
919
920 private:
921 void eval(const std::vector<array>& inputs, array& out);
922 bool equal_nan_;
923};
924
925class Erf : public UnaryPrimitive {
926 public:
927 explicit Erf(Stream stream) : UnaryPrimitive(stream) {}
928
929 void eval_cpu(const std::vector<array>& inputs, array& out) override;
930 void eval_gpu(const std::vector<array>& inputs, array& out) override;
931
937
938 private:
939 void eval(const std::vector<array>& inputs, array& out);
940};
941
942class ErfInv : public UnaryPrimitive {
943 public:
944 explicit ErfInv(Stream stream) : UnaryPrimitive(stream) {}
945
946 void eval_cpu(const std::vector<array>& inputs, array& out) override;
947 void eval_gpu(const std::vector<array>& inputs, array& out) override;
948
954
955 private:
956 void eval(const std::vector<array>& inputs, array& out);
957};
958
959class Exp : public UnaryPrimitive {
960 public:
961 explicit Exp(Stream stream) : UnaryPrimitive(stream) {}
962
963 void eval_cpu(const std::vector<array>& inputs, array& out) override;
964 void eval_gpu(const std::vector<array>& inputs, array& out) override;
965
971
972 private:
973 void eval(const std::vector<array>& inputs, array& out);
974};
975
976class Expm1 : public UnaryPrimitive {
977 public:
978 explicit Expm1(Stream stream) : UnaryPrimitive(stream) {}
979
980 void eval_cpu(const std::vector<array>& inputs, array& out) override;
981 void eval_gpu(const std::vector<array>& inputs, array& out) override;
982
987
988 private:
989 void eval(const std::vector<array>& inputs, array& out);
990};
991
992class FFT : public UnaryPrimitive {
993 public:
994 explicit FFT(
995 Stream stream,
996 const std::vector<size_t>& axes,
997 bool inverse,
998 bool real)
999 : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {}
1000
1001 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1002 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1003
1007
1008 bool is_equivalent(const Primitive& other) const override;
1009
1010 private:
1011 std::vector<size_t> axes_;
1012 bool inverse_;
1013 bool real_;
1014
1015 void eval(const std::vector<array>& inputs, array& out);
1016};
1017
1018class Floor : public UnaryPrimitive {
1019 public:
1020 explicit Floor(Stream stream) : UnaryPrimitive(stream) {}
1021
1022 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1023 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1024
1030
1031 private:
1032 void eval(const std::vector<array>& inputs, array& out);
1033};
1034
1035class Full : public UnaryPrimitive {
1036 public:
1037 explicit Full(Stream stream) : UnaryPrimitive(stream) {}
1038
1039 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1040 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1041
1046
1047 private:
1048 void eval(const std::vector<array>& inputs, array& out);
1049};
1050
1051class Gather : public UnaryPrimitive {
1052 public:
1053 explicit Gather(
1054 Stream stream,
1055 const std::vector<int>& axes,
1056 const std::vector<int>& slice_sizes)
1057 : UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes) {}
1058
1059 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1060 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1061
1065 bool is_equivalent(const Primitive& other) const override;
1066
1067 private:
1068 void eval(const std::vector<array>& inputs, array& out);
1069 std::vector<int> axes_;
1070 std::vector<int> slice_sizes_;
1071};
1072
1073class Greater : public UnaryPrimitive {
1074 public:
1075 explicit Greater(Stream stream) : UnaryPrimitive(stream) {}
1076
1077 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1078 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1079
1085
1086 private:
1087 void eval(const std::vector<array>& inputs, array& out);
1088};
1089
1091 public:
1092 explicit GreaterEqual(Stream stream) : UnaryPrimitive(stream) {}
1093
1094 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1095 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1096
1102
1103 private:
1104 void eval(const std::vector<array>& inputs, array& out);
1105};
1106
1107class Hadamard : public UnaryPrimitive {
1108 public:
1109 explicit Hadamard(Stream stream, float scale)
1110 : UnaryPrimitive(stream), scale_(scale) {}
1111
1112 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1113 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1114
1119
1120 bool is_equivalent(const Primitive& other) const override;
1121
1122 private:
1123 float scale_;
1124
1125 void eval(const std::vector<array>& inputs, array& out);
1126};
1127
1128class Imag : public UnaryPrimitive {
1129 public:
1130 explicit Imag(Stream stream) : UnaryPrimitive(stream) {}
1131
1132 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1133 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1134
1140};
1141
1142class Less : public UnaryPrimitive {
1143 public:
1144 explicit Less(Stream stream) : UnaryPrimitive(stream) {}
1145
1146 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1147 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1148
1154
1155 private:
1156 void eval(const std::vector<array>& inputs, array& out);
1157};
1158
1160 public:
1161 explicit LessEqual(Stream stream) : UnaryPrimitive(stream) {}
1162
1163 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1164 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1165
1171
1172 private:
1173 void eval(const std::vector<array>& inputs, array& out);
1174};
1175
1176class Load : public UnaryPrimitive {
1177 public:
1178 explicit Load(
1179 Stream stream,
1180 std::shared_ptr<io::Reader> reader,
1181 size_t offset,
1182 bool swap_endianness = false)
1183 : UnaryPrimitive(stream),
1184 reader_(std::move(reader)),
1185 offset_(offset),
1186 swap_endianness_(swap_endianness) {
1187 if (stream.device == Device::gpu) {
1188 io_stream();
1189 }
1190 }
1191
1192 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1193 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1194
1196
1197 private:
1198 Stream& io_stream() {
1199 static Stream io_stream = new_stream(Device::cpu);
1200 return io_stream;
1201 };
1202 void eval(const std::vector<array>& inputs, array& out);
1203 std::shared_ptr<io::Reader> reader_;
1204 size_t offset_;
1205 bool swap_endianness_;
1206};
1207
1208class Log : public UnaryPrimitive {
1209 public:
1210 enum Base { two, ten, e };
1211
1212 explicit Log(Stream stream, Base base)
1213 : UnaryPrimitive(stream), base_(base) {}
1214
1215 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1216 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1217
1222
1223 void print(std::ostream& os) override {
1224 switch (base_) {
1225 case e:
1226 os << "Log";
1227 break;
1228 case two:
1229 os << "Log2";
1230 break;
1231 case ten:
1232 os << "Log10";
1233 break;
1234 }
1235 }
1236
1237 private:
1238 Base base_;
1239 void eval(const std::vector<array>& inputs, array& out);
1240};
1241
1242class Log1p : public UnaryPrimitive {
1243 public:
1244 explicit Log1p(Stream stream) : UnaryPrimitive(stream) {}
1245
1246 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1247 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1248
1253
1254 private:
1255 void eval(const std::vector<array>& inputs, array& out);
1256};
1257
1259 public:
1260 explicit LogicalNot(Stream stream) : UnaryPrimitive(stream) {}
1261
1262 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1263 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1264
1270
1271 private:
1272 void eval(const std::vector<array>& inputs, array& out);
1273};
1274
1276 public:
1277 explicit LogicalAnd(Stream stream) : UnaryPrimitive(stream) {}
1278
1279 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1280 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1281
1287
1288 private:
1289 void eval(const std::vector<array>& inputs, array& out);
1290};
1291
1293 public:
1294 explicit LogicalOr(Stream stream) : UnaryPrimitive(stream) {}
1295
1296 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1297 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1298
1304
1305 private:
1306 void eval(const std::vector<array>& inputs, array& out);
1307};
1308
1310 public:
1311 explicit LogAddExp(Stream stream) : UnaryPrimitive(stream) {}
1312
1313 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1314 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1315
1321
1322 private:
1323 void eval(const std::vector<array>& inputs, array& out);
1324};
1325
1326class Matmul : public UnaryPrimitive {
1327 public:
1328 explicit Matmul(Stream stream) : UnaryPrimitive(stream) {}
1329
1330 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1331 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1332
1333 std::vector<array> vjp(
1334 const std::vector<array>& primals,
1335 const std::vector<array>& cotangents,
1336 const std::vector<int>& argnums,
1337 const std::vector<array>& outputs) override;
1338
1342};
1343
1344class Maximum : public UnaryPrimitive {
1345 public:
1346 explicit Maximum(Stream stream) : UnaryPrimitive(stream) {}
1347
1348 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1349 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1350
1356
1357 private:
1358 void eval(const std::vector<array>& inputs, array& out);
1359};
1360
1361class Minimum : public UnaryPrimitive {
1362 public:
1363 explicit Minimum(Stream stream) : UnaryPrimitive(stream) {}
1364
1365 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1366 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1367
1373
1374 private:
1375 void eval(const std::vector<array>& inputs, array& out);
1376};
1377
1378class Multiply : public UnaryPrimitive {
1379 public:
1380 explicit Multiply(Stream stream) : UnaryPrimitive(stream) {}
1381
1382 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1383 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1384
1390
1391 private:
1392 void eval(const std::vector<array>& inputs, array& out);
1393};
1394
1395class Negative : public UnaryPrimitive {
1396 public:
1397 explicit Negative(Stream stream) : UnaryPrimitive(stream) {}
1398
1399 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1400 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1401
1407
1408 private:
1409 void eval(const std::vector<array>& inputs, array& out);
1410};
1411
1412class NotEqual : public UnaryPrimitive {
1413 public:
1414 explicit NotEqual(Stream stream) : UnaryPrimitive(stream) {}
1415
1416 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1417 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1418
1424
1425 private:
1426 void eval(const std::vector<array>& inputs, array& out);
1427};
1428
1430 public:
1432 Stream stream,
1433 std::vector<int> axes,
1434 bool inverted,
1435 Dtype dtype)
1436 : UnaryPrimitive(stream),
1437 axes_(std::move(axes)),
1438 inverted_(inverted),
1439 dtype_(dtype) {}
1440
1441 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1442 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1443
1446 bool is_equivalent(const Primitive& other) const override;
1447 std::vector<std::vector<int>> output_shapes(
1448 const std::vector<array>& inputs) override {
1449 return {{}};
1450 }
1451
1452 private:
1453 std::vector<int> axes_;
1454 bool inverted_;
1455 Dtype dtype_;
1456
1457 void eval(const std::vector<array>& inputs, array& out);
1458};
1459
1460class Pad : public UnaryPrimitive {
1461 public:
1462 explicit Pad(
1463 Stream stream,
1464 const std::vector<int>& axes,
1465 const std::vector<int>& low_pad_size,
1466 const std::vector<int>& high_pad_size)
1467 : UnaryPrimitive(stream),
1468 axes_(axes),
1469 low_pad_size_(low_pad_size),
1470 high_pad_size_(high_pad_size) {}
1471
1472 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1473 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1474
1478 bool is_equivalent(const Primitive& other) const override;
1479
1480 private:
1481 std::vector<int> axes_;
1482 std::vector<int> low_pad_size_;
1483 std::vector<int> high_pad_size_;
1484
1485 void eval(const std::vector<array>& inputs, array& out);
1486};
1487
1489 public:
1490 explicit Partition(Stream stream, int kth, int axis)
1491 : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
1492
1493 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1494 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1495
1500 bool is_equivalent(const Primitive& other) const override;
1501
1502 private:
1503 int kth_;
1504 int axis_;
1505
1506 void eval(const std::vector<array>& inputs, array& out);
1507};
1508
1509class Power : public UnaryPrimitive {
1510 public:
1511 explicit Power(Stream stream) : UnaryPrimitive(stream) {}
1512
1513 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1514 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1515
1521
1522 private:
1523 void eval(const std::vector<array>& inputs, array& out);
1524};
1525
1527 public:
1529 Stream stream,
1530 int group_size,
1531 int bits,
1532 bool transpose)
1533 : UnaryPrimitive(stream),
1534 group_size_(group_size),
1535 bits_(bits),
1536 transpose_(transpose) {}
1537
1538 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1539 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1540
1544 bool is_equivalent(const Primitive& other) const override;
1545
1546 private:
1547 int group_size_;
1548 int bits_;
1549 bool transpose_;
1550
1551 void eval(const std::vector<array>& inputs, array& out);
1552};
1553
1555 public:
1556 explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
1557 : UnaryPrimitive(stream),
1558 group_size_(group_size),
1559 bits_(bits),
1560 transpose_(transpose) {}
1561
1562 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1563 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1564
1568 bool is_equivalent(const Primitive& other) const override;
1569
1570 private:
1571 int group_size_;
1572 int bits_;
1573 bool transpose_;
1574
1575 void eval(const std::vector<array>& inputs, array& out);
1576};
1577
1579 public:
1580 explicit RandomBits(Stream stream, const std::vector<int>& shape, int width)
1581 : UnaryPrimitive(stream), shape_(shape), width_(width) {}
1582
1583 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1584 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1585
1588 bool is_equivalent(const Primitive& other) const override;
1589
1590 private:
1591 std::vector<int> shape_;
1592 int width_;
1593
1594 void eval(const std::vector<array>& inputs, array& out);
1595};
1596
1597class Real : public UnaryPrimitive {
1598 public:
1599 explicit Real(Stream stream) : UnaryPrimitive(stream) {}
1600
1601 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1602 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1603
1609};
1610
1611class Reshape : public UnaryPrimitive {
1612 public:
1613 explicit Reshape(Stream stream, const std::vector<int>& shape)
1614 : UnaryPrimitive(stream), shape_(shape) {}
1615
1616 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1617 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1618
1622 bool is_equivalent(const Primitive& other) const override;
1623
1624 private:
1625 std::vector<int> shape_;
1626
1627 void eval(const std::vector<array>& inputs, array& out);
1628
1629 std::pair<bool, std::vector<size_t>> prepare_reshape(
1630 const array& in,
1631 const array& out);
1632 void shared_buffer_reshape(
1633 const array& in,
1634 const std::vector<size_t>& out_strides,
1635 array& out);
1636};
1637
1638class Reduce : public UnaryPrimitive {
1639 public:
1640 enum ReduceType { And, Or, Sum, Prod, Min, Max };
1641
1642 explicit Reduce(
1643 Stream stream,
1644 ReduceType reduce_type,
1645 const std::vector<int>& axes)
1646 : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1647
1648 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1649 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1650
1652
1653 std::vector<array> vjp(
1654 const std::vector<array>& primals,
1655 const std::vector<array>& cotangents,
1656 const std::vector<int>& argnums,
1657 const std::vector<array>& outputs) override;
1658
1659 std::vector<std::vector<int>> output_shapes(
1660 const std::vector<array>& inputs) override;
1661
1662 void print(std::ostream& os) override {
1663 switch (reduce_type_) {
1664 case And:
1665 os << "And";
1666 break;
1667 case Or:
1668 os << "Or";
1669 break;
1670 case Sum:
1671 os << "Sum";
1672 break;
1673 case Prod:
1674 os << "Prod";
1675 break;
1676 case Min:
1677 os << "Min";
1678 break;
1679 case Max:
1680 os << "Max";
1681 break;
1682 }
1683 }
1684 bool is_equivalent(const Primitive& other) const override;
1685
1686 private:
1687 ReduceType reduce_type_;
1688 std::vector<int> axes_;
1689
1690 void eval(const std::vector<array>& inputs, array& out);
1691};
1692
1693class Round : public UnaryPrimitive {
1694 public:
1695 explicit Round(Stream stream) : UnaryPrimitive(stream) {}
1696
1697 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1698 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1699
1705
1706 private:
1707 void eval(const std::vector<array>& inputs, array& out);
1708};
1709
1710class Scan : public UnaryPrimitive {
1711 public:
1713
1714 explicit Scan(
1715 Stream stream,
1716 ReduceType reduce_type,
1717 int axis,
1718 bool reverse,
1719 bool inclusive)
1720 : UnaryPrimitive(stream),
1721 reduce_type_(reduce_type),
1722 axis_(axis),
1723 reverse_(reverse),
1724 inclusive_(inclusive) {}
1725
1726 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1727 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1728
1731
1732 void print(std::ostream& os) override {
1733 os << "Cum";
1734 switch (reduce_type_) {
1735 case Sum:
1736 os << "Sum";
1737 break;
1738 case Prod:
1739 os << "Prod";
1740 break;
1741 case Min:
1742 os << "Min";
1743 break;
1744 case Max:
1745 os << "Max";
1746 break;
1747 }
1748 }
1749 bool is_equivalent(const Primitive& other) const override;
1750
1751 private:
1752 ReduceType reduce_type_;
1753 int axis_;
1754 bool reverse_;
1755 bool inclusive_;
1756
1757 void eval(const std::vector<array>& inputs, array& out);
1758};
1759
1760class Scatter : public UnaryPrimitive {
1761 public:
1763
1764 explicit Scatter(
1765 Stream stream,
1766 ReduceType reduce_type,
1767 const std::vector<int>& axes)
1768 : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1769
1770 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1771 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1772
1775
1776 void print(std::ostream& os) override {
1777 os << "Scatter";
1778 switch (reduce_type_) {
1779 case Sum:
1780 os << " Sum";
1781 break;
1782 case Prod:
1783 os << " Prod";
1784 break;
1785 case Min:
1786 os << " Min";
1787 break;
1788 case Max:
1789 os << " Max";
1790 break;
1791 case None:
1792 break;
1793 }
1794 }
1795 bool is_equivalent(const Primitive& other) const override;
1796
1797 private:
1798 void eval(const std::vector<array>& inputs, array& out);
1799 ReduceType reduce_type_;
1800 std::vector<int> axes_;
1801};
1802
1803class Sigmoid : public UnaryPrimitive {
1804 public:
1805 explicit Sigmoid(Stream stream) : UnaryPrimitive(stream) {}
1806
1807 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1808 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1809
1815
1816 private:
1817 void eval(const std::vector<array>& inputs, array& out);
1818};
1819
1820class Sign : public UnaryPrimitive {
1821 public:
1822 explicit Sign(Stream stream) : UnaryPrimitive(stream) {}
1823
1824 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1825 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1826
1832
1833 private:
1834 void eval(const std::vector<array>& inputs, array& out);
1835};
1836
1837class Sin : public UnaryPrimitive {
1838 public:
1839 explicit Sin(Stream stream) : UnaryPrimitive(stream) {}
1840
1841 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1842 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1843
1849
1850 private:
1851 void eval(const std::vector<array>& inputs, array& out);
1852};
1853
1854class Sinh : public UnaryPrimitive {
1855 public:
1856 explicit Sinh(Stream stream) : UnaryPrimitive(stream) {}
1857
1858 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1859 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1860
1866
1867 private:
1868 void eval(const std::vector<array>& inputs, array& out);
1869};
1870
1871class Slice : public UnaryPrimitive {
1872 public:
1873 explicit Slice(
1874 Stream stream,
1875 const std::vector<int>& start_indices,
1876 const std::vector<int>& end_indices,
1877 const std::vector<int>& strides)
1878 : UnaryPrimitive(stream),
1879 start_indices_(start_indices),
1880 end_indices_(end_indices),
1881 strides_(strides) {}
1882
1883 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1884 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1885
1889 bool is_equivalent(const Primitive& other) const override;
1890
1891 private:
1892 std::vector<int> start_indices_;
1893 std::vector<int> end_indices_;
1894 std::vector<int> strides_;
1895
1896 void eval(const std::vector<array>& inputs, array& out);
1897};
1898
1900 public:
1901 explicit SliceUpdate(
1902 Stream stream,
1903 const std::vector<int>& start_indices,
1904 const std::vector<int>& end_indices,
1905 const std::vector<int>& strides)
1906 : UnaryPrimitive(stream),
1907 start_indices_(start_indices),
1908 end_indices_(end_indices),
1909 strides_(strides) {}
1910
1911 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1912 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1913
1917 bool is_equivalent(const Primitive& other) const override;
1918
1919 private:
1920 std::vector<int> start_indices_;
1921 std::vector<int> end_indices_;
1922 std::vector<int> strides_;
1923
1924 void eval(const std::vector<array>& inputs, array& out);
1925
1926 std::tuple<int64_t, std::vector<int64_t>> prepare_slice(const array& in);
1927};
1928
1929class Softmax : public UnaryPrimitive {
1930 public:
1931 explicit Softmax(Stream stream, bool precise)
1932 : UnaryPrimitive(stream), precise_(precise) {}
1933
1934 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1935 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1936
1941
1942 bool is_equivalent(const Primitive& other) const override;
1943
1944 private:
1945 void eval(const std::vector<array>& inputs, array& out);
1946 bool precise_;
1947};
1948
1949class Sort : public UnaryPrimitive {
1950 public:
1951 explicit Sort(Stream stream, int axis)
1952 : UnaryPrimitive(stream), axis_(axis) {}
1953
1954 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1955 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1956
1961 bool is_equivalent(const Primitive& other) const override;
1962
1963 private:
1964 int axis_;
1965
1966 void eval(const std::vector<array>& inputs, array& out);
1967};
1968
1969class Split : public Primitive {
1970 public:
1971 explicit Split(Stream stream, const std::vector<int>& indices, int axis)
1972 : Primitive(stream), indices_(indices), axis_(axis) {}
1973
1974 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
1975 override;
1976 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
1977 override;
1978
1982 bool is_equivalent(const Primitive& other) const override;
1983
1984 private:
1985 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
1986
1987 std::vector<int> indices_;
1988 int axis_;
1989};
1990
1991class Square : public UnaryPrimitive {
1992 public:
1993 explicit Square(Stream stream) : UnaryPrimitive(stream) {}
1994
1995 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1996 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1997
2003
2004 private:
2005 void eval(const std::vector<array>& inputs, array& out);
2006};
2007
2008class Sqrt : public UnaryPrimitive {
2009 public:
2010 explicit Sqrt(Stream stream, bool recip = false)
2011 : UnaryPrimitive(stream), recip_(recip) {}
2012
2013 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2014 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2015
2019 bool is_equivalent(const Primitive& other) const override;
2020
2021 void print(std::ostream& os) override {
2022 if (recip_) {
2023 os << "Rsqrt";
2024 } else {
2025 os << "Sqrt";
2026 }
2027 }
2028
2029 private:
2030 void eval(const std::vector<array>& inputs, array& out);
2031 bool recip_;
2032};
2033
2035 public:
2036 explicit StopGradient(Stream stream) : UnaryPrimitive(stream) {}
2037
2038 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2039 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2040
2045
2046 private:
2047 void eval(const std::vector<array>& inputs, array& out);
2048};
2049
2050class Subtract : public UnaryPrimitive {
2051 public:
2052 explicit Subtract(Stream stream) : UnaryPrimitive(stream) {}
2053
2054 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2055 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2056
2062
2063 private:
2064 void eval(const std::vector<array>& inputs, array& out);
2065};
2066
2067class Tan : public UnaryPrimitive {
2068 public:
2069 explicit Tan(Stream stream) : UnaryPrimitive(stream) {}
2070
2071 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2072 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2073
2079
2080 private:
2081 void eval(const std::vector<array>& inputs, array& out);
2082};
2083
2084class Tanh : public UnaryPrimitive {
2085 public:
2086 explicit Tanh(Stream stream) : UnaryPrimitive(stream) {}
2087
2088 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2089 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2090
2096
2097 private:
2098 void eval(const std::vector<array>& inputs, array& out);
2099};
2100
2101class Uniform : public UnaryPrimitive {
2102 public:
2103 explicit Uniform(Stream stream) : UnaryPrimitive(stream) {}
2104
2105 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2106 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2107
2111
2112 private:
2113 void eval(const std::vector<array>& inputs, array& out);
2114};
2115
2116class View : public UnaryPrimitive {
2117 public:
2118 explicit View(Stream stream, Dtype dtype)
2119 : UnaryPrimitive(stream), dtype_(dtype) {}
2120
2121 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2122 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2123
2125 void print(std::ostream& os) override;
2126 bool is_equivalent(const Primitive& other) const override;
2127
2128 private:
2129 Dtype dtype_;
2130};
2131
2133 public:
2134 explicit Transpose(Stream stream, const std::vector<int>& axes)
2135 : UnaryPrimitive(stream), axes_(axes) {}
2136
2137 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2138 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2139
2143 bool is_equivalent(const Primitive& other) const override;
2144
2145 private:
2146 std::vector<int> axes_;
2147
2148 void eval(const std::vector<array>& inputs, array& out);
2149};
2150
2151/* QR Factorization primitive. */
2152class QRF : public Primitive {
2153 public:
2154 explicit QRF(Stream stream) : Primitive(stream) {}
2155
2156 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2157 override;
2158 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2159 override;
2160
2162
2163 private:
2164 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
2165};
2166
2167/* SVD primitive. */
2168class SVD : public Primitive {
2169 public:
2170 explicit SVD(Stream stream) : Primitive(stream) {}
2171
2172 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2173 override;
2174 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2175 override;
2176
2179
2180 private:
2181 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
2182};
2183
2184/* Matrix inversion primitive. */
2185class Inverse : public UnaryPrimitive {
2186 public:
2187 explicit Inverse(Stream stream, bool tri, bool upper)
2188 : UnaryPrimitive(stream), tri_(tri), upper_(upper) {}
2189
2190 void eval_cpu(const std::vector<array>& inputs, array& output) override;
2191 void eval_gpu(const std::vector<array>& inputs, array& output) override;
2192
2195
2196 private:
2197 void eval(const std::vector<array>& inputs, array& output);
2198 bool tri_;
2199 bool upper_;
2200};
2201
2202class Cholesky : public UnaryPrimitive {
2203 public:
2204 explicit Cholesky(Stream stream, bool upper)
2205 : UnaryPrimitive(stream), upper_(upper) {}
2206
2207 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2208 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2209
2212
2213 private:
2214 void eval(const std::vector<array>& inputs, array& output);
2215 bool upper_;
2216};
2217
2218class Eigh : public Primitive {
2219 public:
2220 explicit Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
2221 : Primitive(stream),
2222 uplo_(std::move(uplo)),
2223 compute_eigenvectors_(compute_eigenvectors) {}
2224
2225 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2226 override;
2227 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2228 override;
2229
2232
2233 std::vector<std::vector<int>> output_shapes(
2234 const std::vector<array>& inputs) override {
2235 auto shape = inputs[0].shape();
2236 shape.pop_back(); // Remove last dimension for eigenvalues
2237 if (compute_eigenvectors_) {
2238 return {shape, inputs[0].shape()}; // Eigenvalues and eigenvectors
2239 } else {
2240 return {shape}; // Only eigenvalues
2241 }
2242 }
2243
2244 bool is_equivalent(const Primitive& other) const override {
2245 if (auto* p = dynamic_cast<const Eigh*>(&other)) {
2246 return uplo_ == p->uplo_ &&
2247 compute_eigenvectors_ == p->compute_eigenvectors_;
2248 }
2249 return false;
2250 }
2251
2252 private:
2253 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
2254 std::string uplo_;
2255 bool compute_eigenvectors_;
2256};
2257
2258} // namespace mlx::core
Definition primitives.h:155
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Abs(Stream stream)
Definition primitives.h:157
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:164
std::vector< std::vector< int > > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:166
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:165
Definition primitives.h:172
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Add(Stream stream)
Definition primitives.h:174
Definition primitives.h:189
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
AddMM(Stream stream, float alpha, float beta)
Definition primitives.h:191
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition primitives.h:213
Arange(Stream stream, double start, double stop, double step)
Definition primitives.h:215
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:232
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCos(Stream stream)
Definition primitives.h:234
Definition primitives.h:249
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCosh(Stream stream)
Definition primitives.h:251
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:266
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcSin(Stream stream)
Definition primitives.h:268
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:283
ArcSinh(Stream stream)
Definition primitives.h:285
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:317
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTan2(Stream stream)
Definition primitives.h:319
Definition primitives.h:300
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcTan(Stream stream)
Definition primitives.h:302
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:334
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTanh(Stream stream)
Definition primitives.h:336
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:351
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArgPartition(Stream stream, int kth, int axis)
Definition primitives.h:353
Definition primitives.h:372
ReduceType
Definition primitives.h:374
@ ArgMin
Definition primitives.h:375
@ ArgMax
Definition primitives.h:376
ArgReduce(Stream stream, ReduceType reduce_type, int axis)
Definition primitives.h:379
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:399
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArgSort(Stream stream, int axis)
Definition primitives.h:401
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:438
AsStrided(Stream stream, std::vector< int > shape, std::vector< size_t > strides, size_t offset)
Definition primitives.h:440
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:418
void eval_gpu(const std::vector< array > &inputs, array &out) override
AsType(Stream stream, Dtype dtype)
Definition primitives.h:420
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:465
BitwiseBinary(Stream stream, Op op)
Definition primitives.h:469
void eval_cpu(const std::vector< array > &inputs, array &out) override
Op
Definition primitives.h:467
@ And
Definition primitives.h:467
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:485
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
BlockMaskedMM(Stream stream, int block_size)
Definition primitives.h:487
Definition primitives.h:528
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Broadcast(Stream stream, const std::vector< int > &shape)
Definition primitives.h:530
Definition primitives.h:547
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Ceil(Stream stream)
Definition primitives.h:549
Definition primitives.h:2202
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cholesky(Stream stream, bool upper)
Definition primitives.h:2204
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:564
Compiled(Stream stream, std::vector< array > inputs, std::vector< array > outputs, std::vector< array > tape, std::unordered_set< uintptr_t > constant_ids)
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:607
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Concatenate(Stream stream, int axis)
Definition primitives.h:609
Definition primitives.h:626
Conjugate(Stream stream)
Definition primitives.h:628
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:642
Contiguous(Stream stream, bool allow_col_major)
Definition primitives.h:644
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:661
void eval_gpu(const std::vector< array > &inputs, array &out) override
Convolution(Stream stream, const std::vector< int > &kernel_strides, const std::vector< int > &padding, const std::vector< int > &kernel_dilation, const std::vector< int > &input_dilation, const int groups=1, const bool flip=false)
Definition primitives.h:663
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition primitives.h:702
void eval_gpu(const std::vector< array > &inputs, array &out) override
Copy(Stream stream)
Definition primitives.h:704
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:719
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cos(Stream stream)
Definition primitives.h:721
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:736
void eval_gpu(const std::vector< array > &inputs, array &out) override
Cosh(Stream stream)
Definition primitives.h:738
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:753
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
CustomTransforms(Stream stream, int num_outputs, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> vjp, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< int > &)> jvp, std::function< std::pair< std::vector< array >, std::vector< int > >(const std::vector< array > &, const std::vector< int > &)> vmap)
Definition primitives.h:755
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:805
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Depends(Stream stream)
Definition primitives.h:807
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:843
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
DivMod(Stream stream)
Definition primitives.h:845
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:826
Divide(Stream stream)
Definition primitives.h:828
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2218
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:2244
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
Definition primitives.h:2220
Definition primitives.h:899
Equal(Stream stream, bool equal_nan=false)
Definition primitives.h:901
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:925
Erf(Stream stream)
Definition primitives.h:927
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:942
void eval_gpu(const std::vector< array > &inputs, array &out) override
ErfInv(Stream stream)
Definition primitives.h:944
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:959
Exp(Stream stream)
Definition primitives.h:961
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:976
Expm1(Stream stream)
Definition primitives.h:978
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:992
FFT(Stream stream, const std::vector< size_t > &axes, bool inverse, bool real)
Definition primitives.h:994
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1018
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Floor(Stream stream)
Definition primitives.h:1020
Definition primitives.h:1035
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Full(Stream stream)
Definition primitives.h:1037
Definition primitives.h:1051
Gather(Stream stream, const std::vector< int > &axes, const std::vector< int > &slice_sizes)
Definition primitives.h:1053
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:508
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
GatherMM(Stream stream)
Definition primitives.h:510
Definition primitives.h:1554
GatherQMM(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1556
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1090
void eval_cpu(const std::vector< array > &inputs, array &out) override
GreaterEqual(Stream stream)
Definition primitives.h:1092
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1073
Greater(Stream stream)
Definition primitives.h:1075
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1107
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Hadamard(Stream stream, float scale)
Definition primitives.h:1109
Definition primitives.h:1128
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Imag(Stream stream)
Definition primitives.h:1130
Definition primitives.h:2185
void eval_gpu(const std::vector< array > &inputs, array &output) override
Inverse(Stream stream, bool tri, bool upper)
Definition primitives.h:2187
void eval_cpu(const std::vector< array > &inputs, array &output) override
Definition primitives.h:1159
LessEqual(Stream stream)
Definition primitives.h:1161
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1142
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Less(Stream stream)
Definition primitives.h:1144
Definition primitives.h:1176
void eval_gpu(const std::vector< array > &inputs, array &out) override
Load(Stream stream, std::shared_ptr< io::Reader > reader, size_t offset, bool swap_endianness=false)
Definition primitives.h:1178
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1242
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Log1p(Stream stream)
Definition primitives.h:1244
Definition primitives.h:1309
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogAddExp(Stream stream)
Definition primitives.h:1311
Definition primitives.h:1208
Base
Definition primitives.h:1210
Log(Stream stream, Base base)
Definition primitives.h:1212
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1275
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalAnd(Stream stream)
Definition primitives.h:1277
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1258
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalNot(Stream stream)
Definition primitives.h:1260
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1292
void eval_cpu(const std::vector< array > &inputs, array &out) override
LogicalOr(Stream stream)
Definition primitives.h:1294
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1326
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
Matmul(Stream stream)
Definition primitives.h:1328
Definition primitives.h:1344
Maximum(Stream stream)
Definition primitives.h:1346
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1361
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Minimum(Stream stream)
Definition primitives.h:1363
Definition primitives.h:1378
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Multiply(Stream stream)
Definition primitives.h:1380
Definition primitives.h:1395
void eval_gpu(const std::vector< array > &inputs, array &out) override
Negative(Stream stream)
Definition primitives.h:1397
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1412
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
NotEqual(Stream stream)
Definition primitives.h:1414
Definition primitives.h:1429
void eval_gpu(const std::vector< array > &inputs, array &out) override
NumberOfElements(Stream stream, std::vector< int > axes, bool inverted, Dtype dtype)
Definition primitives.h:1431
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1460
void eval_cpu(const std::vector< array > &inputs, array &out) override
Pad(Stream stream, const std::vector< int > &axes, const std::vector< int > &low_pad_size, const std::vector< int > &high_pad_size)
Definition primitives.h:1462
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1488
void eval_cpu(const std::vector< array > &inputs, array &out) override
Partition(Stream stream, int kth, int axis)
Definition primitives.h:1490
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1509
void eval_cpu(const std::vector< array > &inputs, array &out) override
Power(Stream stream)
Definition primitives.h:1511
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:48
virtual void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
The vector-Jacobian product.
virtual ~Primitive()=default
Primitive(const Primitive &other)=delete
Primitive(Primitive &&other)=delete
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
Primitive & operator=(Primitive &&other)=delete
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
Primitive & operator=(const Primitive &other)=delete
virtual std::vector< std::vector< int > > output_shapes(const std::vector< array > &inputs)
Get the output shapes of the primitive.
const Device & device()
The device the primitive will run on.
Definition primitives.h:53
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
The Jacobian-vector product.
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes)
The primitive must know how to vectorize itself across the given axes.
virtual void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
virtual void print(std::ostream &os)=0
Print the primitive.
Primitive(Stream stream)
Definition primitives.h:50
Definition primitives.h:2152
QRF(Stream stream)
Definition primitives.h:2154
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:1526
void eval_gpu(const std::vector< array > &inputs, array &out) override
QuantizedMatmul(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1528
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1578
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
RandomBits(Stream stream, const std::vector< int > &shape, int width)
Definition primitives.h:1580
Definition primitives.h:1597
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Real(Stream stream)
Definition primitives.h:1599
Definition primitives.h:1638
Reduce(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1642
ReduceType
Definition primitives.h:1640
@ And
Definition primitives.h:1640
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:882
Remainder(Stream stream)
Definition primitives.h:884
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1611
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Reshape(Stream stream, const std::vector< int > &shape)
Definition primitives.h:1613
Definition primitives.h:1693
Round(Stream stream)
Definition primitives.h:1695
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2168
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
SVD(Stream stream)
Definition primitives.h:2170
Definition primitives.h:1710
void eval_cpu(const std::vector< array > &inputs, array &out) override
ReduceType
Definition primitives.h:1712
@ Max
Definition primitives.h:1712
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Scan(Stream stream, ReduceType reduce_type, int axis, bool reverse, bool inclusive)
Definition primitives.h:1714
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1760
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
ReduceType
Definition primitives.h:1762
@ Max
Definition primitives.h:1762
void eval_cpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1776
void eval_gpu(const std::vector< array > &inputs, array &out) override
Scatter(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1764
Definition primitives.h:865
void eval_gpu(const std::vector< array > &inputs, array &out) override
Select(Stream stream)
Definition primitives.h:867
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1803
Sigmoid(Stream stream)
Definition primitives.h:1805
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1820
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Sign(Stream stream)
Definition primitives.h:1822
Definition primitives.h:1837
Sin(Stream stream)
Definition primitives.h:1839
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1854
Sinh(Stream stream)
Definition primitives.h:1856
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1871
void eval_cpu(const std::vector< array > &inputs, array &out) override
Slice(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
Definition primitives.h:1873
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1899
SliceUpdate(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
Definition primitives.h:1901
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1929
void eval_gpu(const std::vector< array > &inputs, array &out) override
Softmax(Stream stream, bool precise)
Definition primitives.h:1931
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1949
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sort(Stream stream, int axis)
Definition primitives.h:1951
Definition primitives.h:1969
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Split(Stream stream, const std::vector< int > &indices, int axis)
Definition primitives.h:1971
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:2008
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sqrt(Stream stream, bool recip=false)
Definition primitives.h:2010
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1991
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Square(Stream stream)
Definition primitives.h:1993
Definition primitives.h:2034
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
StopGradient(Stream stream)
Definition primitives.h:2036
Definition primitives.h:2050
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Subtract(Stream stream)
Definition primitives.h:2052
Definition primitives.h:2067
Tan(Stream stream)
Definition primitives.h:2069
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2084
void eval_gpu(const std::vector< array > &inputs, array &out) override
Tanh(Stream stream)
Definition primitives.h:2086
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2132
Transpose(Stream stream, const std::vector< int > &axes)
Definition primitives.h:2134
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:127
UnaryPrimitive & operator=(const UnaryPrimitive &other)=delete
UnaryPrimitive(Stream stream)
An abstract base class for a primitive with a single output.
Definition primitives.h:132
virtual void eval_gpu(const std::vector< array > &inputs, array &output)=0
UnaryPrimitive(UnaryPrimitive &&other)=delete
virtual void eval_cpu(const std::vector< array > &inputs, array &output)=0
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:142
UnaryPrimitive(const UnaryPrimitive &other)=delete
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:137
UnaryPrimitive & operator=(UnaryPrimitive &&other)=delete
virtual ~UnaryPrimitive()=default
Definition primitives.h:2101
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Uniform(Stream stream)
Definition primitives.h:2103
Definition primitives.h:2116
void eval_cpu(const std::vector< array > &inputs, array &out) override
View(Stream stream, Dtype dtype)
Definition primitives.h:2118
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition array.h:20
Op op
Definition binary.h:129
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
array tri(int n, int m, int k, Dtype type, StreamOrDevice s={})
array transpose(const array &a, std::vector< int > axes, StreamOrDevice s={})
Permutes the dimensions according to the given axes.
array real(const array &a, StreamOrDevice s={})
Definition allocator.h:7
std::pair< std::vector< array >, std::vector< array > > jvp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &tangents)
Computes the output and Jacobian-vector product (JVP) of a function.
std::pair< std::vector< array >, std::vector< array > > vjp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &cotangents)
Computes the output and vector-Jacobian product (VJP) of a function.
void eval(std::vector< array > outputs)
std::function< array(const array &)> vmap(const std::function< array(const array &)> &fun, int in_axis=0, int out_axis=0)
Automatically vectorize a unary function over the requested axes.
#define DEFINE_DEFAULT_IS_EQUIVALENT()
Definition primitives.h:34
#define DEFINE_PRINT(PRIMITIVE)
Definition primitives.h:29
#define DEFINE_INPUT_OUTPUT_SHAPE()
Definition primitives.h:39
#define DEFINE_GRADS()
Definition primitives.h:17
#define DEFINE_VMAP()
Definition primitives.h:12
Definition ops.h:37
Definition binary_ops.h:270
Definition ops.h:185
Definition ops.h:163
Definition ops.h:29
Definition ops.h:78
Definition ops.h:141
Definition binary_ops.h:277
Definition ops.h:119
Definition device.h:7
Definition dtype.h:13
Definition stream.h:9
Device device
Definition stream.h:11