MLX
Loading...
Searching...
No Matches
primitives.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <unordered_set>
6
7#include "mlx/array.h"
8#include "mlx/device.h"
9#include "mlx/io/load.h"
10#include "mlx/stream.h"
11
// Declares, inside a primitive's class body, an override of
// Primitive::vmap(inputs, axes). Only the declaration is emitted; the
// definition is expected in an implementation file.
// `override` already implies the function is virtual, so the redundant
// `virtual` keyword is not repeated.
#define DEFINE_VMAP()                                                 \
  std::pair<std::vector<array>, std::vector<int>> vmap(               \
      const std::vector<array>& inputs, const std::vector<int>& axes) \
      override;
16
// Declares, inside a primitive's class body, overrides of both
// Primitive::jvp(primals, tangents, argnums) and
// Primitive::vjp(primals, cotangents, argnums, outputs).
// Only declarations are emitted; definitions are expected elsewhere.
17#define DEFINE_GRADS() \
18 std::vector<array> jvp( \
19 const std::vector<array>& primals, \
20 const std::vector<array>& tangents, \
21 const std::vector<int>& argnums) override; \
22 \
23 std::vector<array> vjp( \
24 const std::vector<array>& primals, \
25 const std::vector<array>& cotangents, \
26 const std::vector<int>& argnums, \
27 const std::vector<array>& outputs) override;
28
// Defines an inline print() override that writes the macro argument,
// stringized with `#`, to the given stream. For example DEFINE_PRINT(Abs)
// makes print() emit the text "Abs".
29#define DEFINE_PRINT(PRIMITIVE) \
30 void print(std::ostream& os) override { \
31 os << #PRIMITIVE; \
32 }
33
// Defines an is_equivalent() override that unconditionally returns true,
// i.e. the argument is never inspected.
// The parameter is intentionally unused, so it is marked [[maybe_unused]] to
// keep every class that expands this macro clean under -Wunused-parameter.
// NOTE(review): presumably intended for primitives without parameters, with
// the caller responsible for checking that both primitives have the same
// type — confirm at the call sites.
#define DEFINE_DEFAULT_IS_EQUIVALENT()                                         \
  bool is_equivalent([[maybe_unused]] const Primitive& other) const override { \
    return true;                                                               \
  }
38
// Defines an output_shapes() override that reports exactly one output whose
// shape is a copy of the first input's shape.
// NOTE(review): inputs[0] is read without checking that `inputs` is
// non-empty; callers must pass at least one input.
39#define DEFINE_INPUT_OUTPUT_SHAPE() \
40 std::vector<std::vector<int>> output_shapes( \
41 const std::vector<array>& inputs) override { \
42 return {inputs[0].shape()}; \
43 }
44
45namespace mlx::core {
46
47// Abstract base class
48class Primitive {
49 public:
50 explicit Primitive(Stream stream) : stream_(stream) {}
51
53 const Device& device() {
54 return stream().device;
55 }
56
58 const Stream& stream() {
59 return stream_;
60 }
61
69 virtual void eval_cpu(
70 const std::vector<array>& inputs,
71 std::vector<array>& outputs) = 0;
72 virtual void eval_gpu(
73 const std::vector<array>& inputs,
74 std::vector<array>& outputs) = 0;
75
79 virtual std::vector<array> jvp(
80 const std::vector<array>& primals,
81 const std::vector<array>& tangents,
82 const std::vector<int>& argnums);
83
87 virtual std::vector<array> vjp(
88 const std::vector<array>& primals,
89 const std::vector<array>& cotangents,
90 const std::vector<int>& argnums,
91 const std::vector<array>& outputs);
92
99 virtual std::pair<std::vector<array>, std::vector<int>> vmap(
100 const std::vector<array>& inputs,
101 const std::vector<int>& axes);
102
104 virtual void print(std::ostream& os) = 0;
105
107 virtual bool is_equivalent(const Primitive& other) const {
108 return false;
109 }
110
113 virtual std::vector<std::vector<int>> output_shapes(
114 const std::vector<array>& inputs);
115
116 virtual ~Primitive() = default;
117 Primitive(const Primitive& other) = delete;
118 Primitive(Primitive&& other) = delete;
119 Primitive& operator=(const Primitive& other) = delete;
120 Primitive& operator=(Primitive&& other) = delete;
121
122 private:
123 // Every primitive stores the stream it should run in
124 Stream stream_;
125};
126
// UnaryPrimitive: abstract base for primitives that produce a single output.
// It adds single-output eval_cpu/eval_gpu overloads and implements the
// multi-output overloads inherited from Primitive by forwarding outputs[0].
// NOTE(review): listing lines 128-130, 132, 150 and 152 are missing from this
// extract; presumably the constructor and the deleted move operations lived
// there — confirm against the upstream header.
127class UnaryPrimitive : public Primitive {
131 public:
133
// Single-output entry points that concrete primitives must implement.
134 virtual void eval_cpu(const std::vector<array>& inputs, array& output) = 0;
135 virtual void eval_gpu(const std::vector<array>& inputs, array& output) = 0;
136
// Adapters for the Primitive interface: only outputs[0] is used, and it is
// indexed without checking that `outputs` is non-empty.
137 inline void eval_cpu(
138 const std::vector<array>& inputs,
139 std::vector<array>& outputs) override {
140 eval_cpu(inputs, outputs[0]);
141 }
142 inline void eval_gpu(
143 const std::vector<array>& inputs,
144 std::vector<array>& outputs) override {
145 eval_gpu(inputs, outputs[0]);
146 }
147
// Copying is disabled.
148 virtual ~UnaryPrimitive() = default;
149 UnaryPrimitive(const UnaryPrimitive& other) = delete;
151 UnaryPrimitive& operator=(const UnaryPrimitive& other) = delete;
153};
154
// Abs: single-output primitive (derives from UnaryPrimitive) with CPU and GPU
// evaluation overrides. Presumably element-wise absolute value (inferred from
// the name only).
// NOTE(review): listing lines 157 and 162-166 are missing from this extract,
// so the constructor and further member declarations are likely absent here —
// confirm against the upstream header.
155class Abs : public UnaryPrimitive {
156 public:
158
// Backend entry points; `out` is the single output to populate.
159 void eval_cpu(const std::vector<array>& inputs, array& out) override;
160 void eval_gpu(const std::vector<array>& inputs, array& out) override;
161
167
168 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
169 void eval(const std::vector<array>& inputs, array& out);
170};
171
// Add: single-output primitive with CPU and GPU evaluation overrides.
// Presumably element-wise addition (inferred from the name only). It derives
// from UnaryPrimitive because it has one output, not one input: the eval
// overloads still take a vector of inputs.
// NOTE(review): listing lines 174 and 179-183 are missing from this extract,
// so the constructor and further member declarations are likely absent here.
172class Add : public UnaryPrimitive {
173 public:
175
176 void eval_cpu(const std::vector<array>& inputs, array& out) override;
177 void eval_gpu(const std::vector<array>& inputs, array& out) override;
178
184
185 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
186 void eval(const std::vector<array>& inputs, array& out);
187};
188
// AddMM: single-output primitive parameterized by two float scalars, alpha
// and beta, fixed at construction. Presumably a fused matrix multiply-add
// scaled by those factors (inferred from the name only — confirm).
// NOTE(review): listing lines 203-204 are missing from this extract.
189class AddMM : public UnaryPrimitive {
190 public:
191 explicit AddMM(Stream stream, float alpha, float beta)
192 : UnaryPrimitive(stream), alpha_(alpha), beta_(beta) {}
193
194 void eval_cpu(const std::vector<array>& inputs, array& out) override;
195 void eval_gpu(const std::vector<array>& inputs, array& out) override;
196
// Only the vjp override is declared explicitly here.
197 std::vector<array> vjp(
198 const std::vector<array>& primals,
199 const std::vector<array>& cotangents,
200 const std::vector<int>& argnums,
201 const std::vector<array>& outputs) override;
202
205
// Declared here, defined elsewhere; presumably compares alpha_ and beta_ — TODO confirm.
206 bool is_equivalent(const Primitive& other) const override;
207
208 private:
// Immutable after construction.
209 const float alpha_;
210 const float beta_;
211};
212
// Arange: single-output primitive storing a start, stop and step as doubles.
// Presumably generates an evenly spaced range (inferred from the name only).
// NOTE(review): listing line 221 is missing from this extract.
213class Arange : public UnaryPrimitive {
214 public:
215 explicit Arange(Stream stream, double start, double stop, double step)
216 : UnaryPrimitive(stream), start_(start), stop_(stop), step_(step) {}
217
218 void eval_cpu(const std::vector<array>& inputs, array& out) override;
219 void eval_gpu(const std::vector<array>& inputs, array& out) override;
220
// Declared here, defined elsewhere.
222 bool is_equivalent(const Primitive& other) const override;
223
224 private:
225 double start_;
226 double stop_;
227 double step_;
228
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
229 void eval(const std::vector<array>& inputs, array& out);
230};
231
// ArcCos: single-output primitive with CPU and GPU evaluation overrides.
// Presumably element-wise inverse cosine (inferred from the name only).
// NOTE(review): listing lines 234 and 239-243 are missing from this extract,
// so the constructor and further member declarations are likely absent here.
232class ArcCos : public UnaryPrimitive {
233 public:
235
236 void eval_cpu(const std::vector<array>& inputs, array& out) override;
237 void eval_gpu(const std::vector<array>& inputs, array& out) override;
238
244
245 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
246 void eval(const std::vector<array>& inputs, array& out);
247};
248
// ArcCosh: single-output primitive with CPU and GPU evaluation overrides.
// Presumably element-wise inverse hyperbolic cosine (inferred from the name).
// NOTE(review): listing lines 251 and 256-260 are missing from this extract,
// so the constructor and further member declarations are likely absent here.
249class ArcCosh : public UnaryPrimitive {
250 public:
252
253 void eval_cpu(const std::vector<array>& inputs, array& out) override;
254 void eval_gpu(const std::vector<array>& inputs, array& out) override;
255
261
262 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
263 void eval(const std::vector<array>& inputs, array& out);
264};
265
// ArcSin: single-output primitive with CPU and GPU evaluation overrides.
// Presumably element-wise inverse sine (inferred from the name only).
// NOTE(review): listing lines 268 and 273-277 are missing from this extract,
// so the constructor and further member declarations are likely absent here.
266class ArcSin : public UnaryPrimitive {
267 public:
269
270 void eval_cpu(const std::vector<array>& inputs, array& out) override;
271 void eval_gpu(const std::vector<array>& inputs, array& out) override;
272
278
279 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
280 void eval(const std::vector<array>& inputs, array& out);
281};
282
// ArcSinh: single-output primitive with CPU and GPU evaluation overrides.
// Presumably element-wise inverse hyperbolic sine (inferred from the name).
// NOTE(review): listing lines 285 and 290-294 are missing from this extract,
// so the constructor and further member declarations are likely absent here.
283class ArcSinh : public UnaryPrimitive {
284 public:
286
287 void eval_cpu(const std::vector<array>& inputs, array& out) override;
288 void eval_gpu(const std::vector<array>& inputs, array& out) override;
289
295
296 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
297 void eval(const std::vector<array>& inputs, array& out);
298};
299
// ArcTan: single-output primitive with CPU and GPU evaluation overrides.
// Presumably element-wise inverse tangent (inferred from the name only).
// NOTE(review): listing lines 302 and 307-311 are missing from this extract,
// so the constructor and further member declarations are likely absent here.
300class ArcTan : public UnaryPrimitive {
301 public:
303
304 void eval_cpu(const std::vector<array>& inputs, array& out) override;
305 void eval_gpu(const std::vector<array>& inputs, array& out) override;
306
312
313 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
314 void eval(const std::vector<array>& inputs, array& out);
315};
316
// ArcTan2: single-output primitive with CPU and GPU evaluation overrides.
// Presumably the two-argument inverse tangent (inferred from the name only).
// NOTE(review): listing lines 319 and 324-328 are missing from this extract,
// so the constructor and further member declarations are likely absent here.
317class ArcTan2 : public UnaryPrimitive {
318 public:
320
321 void eval_cpu(const std::vector<array>& inputs, array& out) override;
322 void eval_gpu(const std::vector<array>& inputs, array& out) override;
323
329
330 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
331 void eval(const std::vector<array>& inputs, array& out);
332};
333
// ArcTanh: single-output primitive with CPU and GPU evaluation overrides.
// Presumably element-wise inverse hyperbolic tangent (inferred from the name).
// NOTE(review): listing lines 336 and 341-345 are missing from this extract,
// so the constructor and further member declarations are likely absent here.
334class ArcTanh : public UnaryPrimitive {
335 public:
337
338 void eval_cpu(const std::vector<array>& inputs, array& out) override;
339 void eval_gpu(const std::vector<array>& inputs, array& out) override;
340
346
347 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
348 void eval(const std::vector<array>& inputs, array& out);
349};
350
// ArgPartition (name taken from the constructor below): single-output
// primitive parameterized by an integer `kth` and an integer `axis`.
// NOTE(review): the `class ... {` header line (listing line 351) is missing
// from this extract, as are listing lines 359-361 — confirm against the
// upstream header before relying on this block.
352 public:
353 explicit ArgPartition(Stream stream, int kth, int axis)
354 : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
355
356 void eval_cpu(const std::vector<array>& inputs, array& out) override;
357 void eval_gpu(const std::vector<array>& inputs, array& out) override;
358
// Declared here, defined elsewhere.
362 bool is_equivalent(const Primitive& other) const override;
363
364 private:
365 int kth_;
366 int axis_;
367
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
368 void eval(const std::vector<array>& inputs, array& out);
369};
370
// ArgReduce: single-output primitive parameterized by a ReduceType and an
// integer axis.
// NOTE(review): listing lines 373-376 and 384-385 are missing from this
// extract; the definition of ReduceType is not visible here and presumably
// sat in the first gap — confirm against the upstream header.
371class ArgReduce : public UnaryPrimitive {
372 public:
377
378 explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)
379 : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {}
380
381 void eval_cpu(const std::vector<array>& inputs, array& out) override;
382 void eval_gpu(const std::vector<array>& inputs, array& out) override;
383
// Both declared here and defined elsewhere.
386 bool is_equivalent(const Primitive& other) const override;
387 std::vector<std::vector<int>> output_shapes(
388 const std::vector<array>& inputs) override;
389
390 private:
391 ReduceType reduce_type_;
392 int axis_;
393
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
394 void eval(const std::vector<array>& inputs, array& out);
395};
396
// ArgSort: single-output primitive parameterized by an integer axis.
// Presumably returns sorting indices along that axis (inferred from the name).
// NOTE(review): listing lines 405-407 are missing from this extract.
397class ArgSort : public UnaryPrimitive {
398 public:
399 explicit ArgSort(Stream stream, int axis)
400 : UnaryPrimitive(stream), axis_(axis) {}
401
402 void eval_cpu(const std::vector<array>& inputs, array& out) override;
403 void eval_gpu(const std::vector<array>& inputs, array& out) override;
404
// Declared here, defined elsewhere.
408 bool is_equivalent(const Primitive& other) const override;
409
410 private:
411 int axis_;
412
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
413 void eval(const std::vector<array>& inputs, array& out);
414};
415
// AsType: single-output primitive parameterized by a Dtype. Presumably casts
// its input to that dtype (inferred from the name only).
// NOTE(review): listing lines 424-427 are missing from this extract.
416class AsType : public UnaryPrimitive {
417 public:
418 explicit AsType(Stream stream, Dtype dtype)
419 : UnaryPrimitive(stream), dtype_(dtype) {}
420
421 void eval_cpu(const std::vector<array>& inputs, array& out) override;
422 void eval_gpu(const std::vector<array>& inputs, array& out) override;
423
// Declared here, defined elsewhere.
428 bool is_equivalent(const Primitive& other) const override;
429
430 private:
431 Dtype dtype_;
432
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
433 void eval(const std::vector<array>& inputs, array& out);
434};
435
// AsStrided: single-output primitive storing a shape, a list of strides and
// an offset.
// NOTE(review): listing lines 439, 443 and 451-452 are missing from this
// extract, so the first constructor parameter and the start of the
// initializer list are not visible — confirm against the upstream header.
436class AsStrided : public UnaryPrimitive {
437 public:
// shape and strides are taken by value and moved into the members.
438 explicit AsStrided(
440 std::vector<int> shape,
441 std::vector<size_t> strides,
442 size_t offset)
444 shape_(std::move(shape)),
445 strides_(std::move(strides)),
446 offset_(offset) {}
447
448 void eval_cpu(const std::vector<array>& inputs, array& out) override;
449 void eval_gpu(const std::vector<array>& inputs, array& out) override;
450
// Declared here, defined elsewhere.
453 bool is_equivalent(const Primitive& other) const override;
454
455 private:
456 std::vector<int> shape_;
457 std::vector<size_t> strides_;
458 size_t offset_;
459
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
460 void eval(const std::vector<array>& inputs, array& out);
461};
462
// Single-output primitive selecting one of five bitwise operations through
// the nested Op enum.
// NOTE(review): the `class ... {` header line (listing line 463) and the
// constructor are missing from this extract, so the class name is not
// visible here (presumably BitwiseBinary — confirm upstream). Listing lines
// 467-468, 473-474 and 477 are missing as well.
464 public:
// The operation this primitive applies.
465 enum Op { And, Or, Xor, LeftShift, RightShift };
466
469
470 void eval_cpu(const std::vector<array>& inputs, array& out) override;
471 void eval_gpu(const std::vector<array>& inputs, array& out) override;
472
// Both declared here and defined elsewhere.
475 bool is_equivalent(const Primitive& other) const override;
476 void print(std::ostream& os) override;
478
479 private:
480 Op op_;
481};
482
// BlockMaskedMM (name taken from the constructor below): single-output
// primitive parameterized by an integer block size.
// NOTE(review): the `class ... {` header line (listing line 483) and listing
// line 497 are missing from this extract — confirm against upstream.
484 public:
485 explicit BlockMaskedMM(Stream stream, int block_size)
486 : UnaryPrimitive(stream), block_size_(block_size) {}
487
488 void eval_cpu(const std::vector<array>& inputs, array& out) override;
489 void eval_gpu(const std::vector<array>& inputs, array& out) override;
490
// Only the vjp override is declared explicitly here.
491 std::vector<array> vjp(
492 const std::vector<array>& primals,
493 const std::vector<array>& cotangents,
494 const std::vector<int>& argnums,
495 const std::vector<array>& outputs) override;
496
// Declared here, defined elsewhere.
498 bool is_equivalent(const Primitive& other) const override;
499
500 private:
501 int block_size_;
502
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
503 void eval(const std::vector<array>& inputs, array& out);
504};
505
// GatherMM: single-output primitive with CPU and GPU evaluation overrides and
// an explicit vjp override. No data members are visible in this extract.
// NOTE(review): listing lines 508 and 519-520 are missing from this extract,
// so the constructor and further member declarations are likely absent here.
506class GatherMM : public UnaryPrimitive {
507 public:
509
510 void eval_cpu(const std::vector<array>& inputs, array& out) override;
511 void eval_gpu(const std::vector<array>& inputs, array& out) override;
512
513 std::vector<array> vjp(
514 const std::vector<array>& primals,
515 const std::vector<array>& cotangents,
516 const std::vector<int>& argnums,
517 const std::vector<array>& outputs) override;
518
521
522 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
523 void eval(const std::vector<array>& inputs, array& out);
524};
525
// Broadcast: single-output primitive storing a copy of a target shape.
// Presumably broadcasts its input to that shape (inferred from the name).
// NOTE(review): listing lines 534-536 are missing from this extract.
526class Broadcast : public UnaryPrimitive {
527 public:
// `shape` is taken by const reference and copied into shape_.
528 explicit Broadcast(Stream stream, const std::vector<int>& shape)
529 : UnaryPrimitive(stream), shape_(shape) {}
530
531 void eval_cpu(const std::vector<array>& inputs, array& out) override;
532 void eval_gpu(const std::vector<array>& inputs, array& out) override;
533
// Declared here, defined elsewhere.
537 bool is_equivalent(const Primitive& other) const override;
538
539 private:
540 std::vector<int> shape_;
541
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
542 void eval(const std::vector<array>& inputs, array& out);
543};
544
// Ceil: single-output primitive with CPU and GPU evaluation overrides.
// Presumably element-wise ceiling (inferred from the name only).
// NOTE(review): listing lines 547 and 552-556 are missing from this extract,
// so the constructor and further member declarations are likely absent here.
545class Ceil : public UnaryPrimitive {
546 public:
548
549 void eval_cpu(const std::vector<array>& inputs, array& out) override;
550 void eval_gpu(const std::vector<array>& inputs, array& out) override;
551
557
558 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
559 void eval(const std::vector<array>& inputs, array& out);
560};
561
// Compiled: multi-output primitive (derives directly from Primitive) that
// holds the inputs, outputs and tape of a traced computation, a set of
// constant ids and the name of a kernel library.
// NOTE(review): listing lines 574 and 585-586 are missing from this extract,
// so the first constructor parameter and some member declarations are not
// visible here — confirm against the upstream header.
562class Compiled : public Primitive {
563 public:
564 /*
565 * The inputs, outputs and tape are either tracers or constants.
566 * - The tape should not contain the inputs, but it should contain the
567 * outputs.
568 * - The tape should also have only one array per primitive for multi-output
569 * primitives.
570 * - The constant_ids contains ids of arrays in the input list that are safe
571 * to treat as scalar constants.
572 */
// Declared only; the constructor is defined elsewhere.
573 explicit Compiled(
575 std::vector<array> inputs,
576 std::vector<array> outputs,
577 std::vector<array> tape,
578 std::unordered_set<uintptr_t> constant_ids);
579
580 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
581 override;
582 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
583 override;
584
587 std::vector<std::vector<int>> output_shapes(
588 const std::vector<array>& inputs) override;
589 void print(std::ostream& os) override;
590 bool is_equivalent(const Primitive& other) const override;
591
// Returns the kernel library name by value, i.e. a copy per call.
592 std::string lib_name() const {
593 return kernel_lib_;
594 }
595
596 private:
// Immutable after construction.
597 const std::vector<array> inputs_;
598 const std::vector<array> outputs_;
599 const std::vector<array> tape_;
600 const std::unordered_set<uintptr_t> constant_ids_;
601
// Not const; presumably assigned in the out-of-line constructor — TODO confirm.
602 std::string kernel_lib_;
603};
604
// Concatenate (name taken from the constructor below): single-output
// primitive parameterized by an integer axis.
// NOTE(review): the `class ... {` header line (listing line 605) and listing
// lines 613-615 are missing from this extract — confirm against upstream.
606 public:
607 explicit Concatenate(Stream stream, int axis)
608 : UnaryPrimitive(stream), axis_(axis) {}
609
610 void eval_cpu(const std::vector<array>& inputs, array& out) override;
611 void eval_gpu(const std::vector<array>& inputs, array& out) override;
612
// Declared here, defined elsewhere.
616 bool is_equivalent(const Primitive& other) const override;
617
618 private:
619 int axis_;
620
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
621 void eval(const std::vector<array>& inputs, array& out);
622};
623
// Conjugate: single-output primitive with no parameters beyond its stream.
// Presumably the complex conjugate (inferred from the name only).
// NOTE(review): listing lines 631-634 are missing from this extract, so some
// member declarations are likely absent here.
624class Conjugate : public UnaryPrimitive {
625 public:
626 explicit Conjugate(Stream stream) : UnaryPrimitive(stream) {}
627
628 void eval_cpu(const std::vector<array>& inputs, array& out) override;
629 void eval_gpu(const std::vector<array>& inputs, array& out) override;
630
635
636 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
637 void eval(const std::vector<array>& inputs, array& out);
638};
639
// Convolution (name taken from the constructor below): single-output
// primitive storing kernel strides, padding, kernel dilation, input
// dilation, a group count (default 1) and a flip flag (default false).
// NOTE(review): the `class ... {` header line (listing line 640) and listing
// line 667 are missing from this extract — confirm against upstream.
641 public:
// The constructor takes kernel_strides before padding, but the initializer
// list starts with padding_ to follow the member declaration order below.
642 explicit Convolution(
643 Stream stream,
644 const std::vector<int>& kernel_strides,
645 const std::vector<int>& padding,
646 const std::vector<int>& kernel_dilation,
647 const std::vector<int>& input_dilation,
648 const int groups = 1,
649 const bool flip = false)
650 : UnaryPrimitive(stream),
651 padding_(padding),
652 kernel_strides_(kernel_strides),
653 kernel_dilation_(kernel_dilation),
654 input_dilation_(input_dilation),
655 groups_(groups),
656 flip_(flip) {}
657
658 void eval_cpu(const std::vector<array>& inputs, array& out) override;
659 void eval_gpu(const std::vector<array>& inputs, array& out) override;
660
// Only the vjp override is declared explicitly here.
661 std::vector<array> vjp(
662 const std::vector<array>& primals,
663 const std::vector<array>& cotangents,
664 const std::vector<int>& argnums,
665 const std::vector<array>& outputs) override;
666
// Declared here, defined elsewhere.
668 bool is_equivalent(const Primitive& other) const override;
669
670 private:
671 std::vector<int> padding_;
672 std::vector<int> kernel_strides_;
673 std::vector<int> kernel_dilation_;
674 std::vector<int> input_dilation_;
675 int groups_;
676 bool flip_;
677
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
678 void eval(const std::vector<array>& inputs, array& out);
679};
680
// Copy: single-output primitive with no parameters beyond its stream.
// Presumably copies its input into the output (inferred from the name only).
// NOTE(review): listing lines 688-692 are missing from this extract, so some
// member declarations are likely absent here.
681class Copy : public UnaryPrimitive {
682 public:
683 explicit Copy(Stream stream) : UnaryPrimitive(stream) {}
684
685 void eval_cpu(const std::vector<array>& inputs, array& out) override;
686 void eval_gpu(const std::vector<array>& inputs, array& out) override;
687
693
694 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
695 void eval(const std::vector<array>& inputs, array& out);
696};
697
// Cos: single-output primitive with no parameters beyond its stream.
// Presumably element-wise cosine (inferred from the name only).
// NOTE(review): listing lines 705-709 are missing from this extract, so some
// member declarations are likely absent here.
698class Cos : public UnaryPrimitive {
699 public:
700 explicit Cos(Stream stream) : UnaryPrimitive(stream) {}
701
702 void eval_cpu(const std::vector<array>& inputs, array& out) override;
703 void eval_gpu(const std::vector<array>& inputs, array& out) override;
704
710
711 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
712 void eval(const std::vector<array>& inputs, array& out);
713};
714
// Cosh: single-output primitive with no parameters beyond its stream.
// Presumably element-wise hyperbolic cosine (inferred from the name only).
// NOTE(review): listing lines 722-726 are missing from this extract, so some
// member declarations are likely absent here.
715class Cosh : public UnaryPrimitive {
716 public:
717 explicit Cosh(Stream stream) : UnaryPrimitive(stream) {}
718
719 void eval_cpu(const std::vector<array>& inputs, array& out) override;
720 void eval_gpu(const std::vector<array>& inputs, array& out) override;
721
727
728 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
729 void eval(const std::vector<array>& inputs, array& out);
730};
731
// Multi-output primitive (its initializer list calls Primitive(stream)) that
// stores an output count and three user-supplied std::function callbacks
// named vjp, jvp and vmap.
// NOTE(review): the `class ... {` header line (listing line 732) and the
// constructor's opening line (listing line 734) are missing from this
// extract, so the class name is not visible here (presumably
// CustomTransforms — confirm upstream). Listing lines 759-761 are missing too.
733 public:
// The callbacks are taken by value and moved into the members; the
// initializer order matches the member declaration order below.
735 Stream stream,
736 int num_outputs,
737 std::function<std::vector<array>(
738 const std::vector<array>&,
739 const std::vector<array>&,
740 const std::vector<array>&)> vjp,
741 std::function<std::vector<array>(
742 const std::vector<array>&,
743 const std::vector<array>&,
744 const std::vector<int>&)> jvp,
745 std::function<std::pair<std::vector<array>, std::vector<int>>(
746 const std::vector<array>&,
747 const std::vector<int>&)> vmap)
748 : Primitive(stream),
749 num_outputs_(num_outputs),
750 vjp_fun_(std::move(vjp)),
751 jvp_fun_(std::move(jvp)),
752 vmap_fun_(std::move(vmap)) {}
753
754 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
755 override;
756 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
757 override;
758
762
763 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
764 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
765
766 int num_outputs_;
767
// Stored callbacks; note vjp_fun_ takes three array vectors whereas jvp_fun_
// takes two array vectors and a vector of ints.
768 std::function<std::vector<array>(
769 const std::vector<array>&,
770 const std::vector<array>&,
771 const std::vector<array>&)>
772 vjp_fun_;
773 std::function<std::vector<array>(
774 const std::vector<array>&,
775 const std::vector<array>&,
776 const std::vector<int>&)>
777 jvp_fun_;
778 std::function<std::pair<std::vector<array>, std::vector<int>>(
779 const std::vector<array>&,
780 const std::vector<int>&)>
781 vmap_fun_;
782};
783
// Depends: multi-output primitive (derives directly from Primitive) with no
// parameters beyond its stream and an explicit vjp override.
// NOTE(review): listing line 799 is missing from this extract, so a member
// declaration is likely absent here — confirm against the upstream header.
784class Depends : public Primitive {
785 public:
786 explicit Depends(Stream stream) : Primitive(stream) {}
787
788 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
789 override;
790 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
791 override;
792
// The second parameter is spelled `cotan` here, versus `cotangents` in the
// base-class declaration; parameter names do not affect overriding.
793 std::vector<array> vjp(
794 const std::vector<array>& primals,
795 const std::vector<array>& cotan,
796 const std::vector<int>& argnums,
797 const std::vector<array>& outputs) override;
798
800
801 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
802 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
803};
804
// Divide: single-output primitive with no parameters beyond its stream.
// Presumably element-wise division (inferred from the name only).
// NOTE(review): listing lines 812-816 are missing from this extract, so some
// member declarations are likely absent here.
805class Divide : public UnaryPrimitive {
806 public:
807 explicit Divide(Stream stream) : UnaryPrimitive(stream) {}
808
809 void eval_cpu(const std::vector<array>& inputs, array& out) override;
810 void eval_gpu(const std::vector<array>& inputs, array& out) override;
811
817
818 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
819 void eval(const std::vector<array>& inputs, array& out);
820};
821
// DivMod: multi-output primitive (derives directly from Primitive) with no
// parameters beyond its stream. Its shape inference reports two outputs.
// NOTE(review): listing lines 831-834 are missing from this extract, so some
// member declarations are likely absent here.
822class DivMod : public Primitive {
823 public:
824 explicit DivMod(Stream stream) : Primitive(stream) {}
825
826 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
827 override;
828 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
829 override;
830
// Two outputs, each with the shape of the first input. inputs[0] is read
// without checking that `inputs` is non-empty.
835 std::vector<std::vector<int>> output_shapes(
836 const std::vector<array>& inputs) override {
837 return std::vector{inputs[0].shape(), inputs[0].shape()};
838 }
839
840 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
841 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
842};
843
// Select: single-output primitive with no parameters beyond its stream.
// NOTE(review): the selection semantics are not visible in this header —
// see the implementation. Listing lines 851-855 are missing from this
// extract, so some member declarations are likely absent here.
844class Select : public UnaryPrimitive {
845 public:
846 explicit Select(Stream stream) : UnaryPrimitive(stream) {}
847
848 void eval_cpu(const std::vector<array>& inputs, array& out) override;
849 void eval_gpu(const std::vector<array>& inputs, array& out) override;
850
856
857 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
858 void eval(const std::vector<array>& inputs, array& out);
859};
860
// Remainder: single-output primitive with no parameters beyond its stream.
// Presumably element-wise remainder (inferred from the name only).
// NOTE(review): listing lines 868-872 are missing from this extract, so some
// member declarations are likely absent here.
861class Remainder : public UnaryPrimitive {
862 public:
863 explicit Remainder(Stream stream) : UnaryPrimitive(stream) {}
864
865 void eval_cpu(const std::vector<array>& inputs, array& out) override;
866 void eval_gpu(const std::vector<array>& inputs, array& out) override;
867
873
874 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
875 void eval(const std::vector<array>& inputs, array& out);
876};
877
// Equal: single-output primitive with an `equal_nan` flag (default false).
// The flag changes the printed name; its effect on evaluation is not visible
// in this header (presumably NaNs compare equal when set — confirm).
// NOTE(review): listing lines 886-889 are missing from this extract, so some
// member declarations are likely absent here.
878class Equal : public UnaryPrimitive {
879 public:
880 explicit Equal(Stream stream, bool equal_nan = false)
881 : UnaryPrimitive(stream), equal_nan_(equal_nan) {}
882
883 void eval_cpu(const std::vector<array>& inputs, array& out) override;
884 void eval_gpu(const std::vector<array>& inputs, array& out) override;
885
890
// Prints "NaNEqual" when equal_nan_ is set and "Equal" otherwise.
891 void print(std::ostream& os) override {
892 if (equal_nan_) {
893 os << "NaNEqual";
894 } else {
895 os << "Equal";
896 }
897 }
898
899 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
900 void eval(const std::vector<array>& inputs, array& out);
901 bool equal_nan_;
902};
903
// Erf: single-output primitive with no parameters beyond its stream.
// Presumably the element-wise error function (inferred from the name only).
// NOTE(review): listing lines 911-915 are missing from this extract, so some
// member declarations are likely absent here.
904class Erf : public UnaryPrimitive {
905 public:
906 explicit Erf(Stream stream) : UnaryPrimitive(stream) {}
907
908 void eval_cpu(const std::vector<array>& inputs, array& out) override;
909 void eval_gpu(const std::vector<array>& inputs, array& out) override;
910
916
917 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
918 void eval(const std::vector<array>& inputs, array& out);
919};
920
// ErfInv: single-output primitive with no parameters beyond its stream.
// Presumably the element-wise inverse error function (inferred from the name).
// NOTE(review): listing lines 928-932 are missing from this extract, so some
// member declarations are likely absent here.
921class ErfInv : public UnaryPrimitive {
922 public:
923 explicit ErfInv(Stream stream) : UnaryPrimitive(stream) {}
924
925 void eval_cpu(const std::vector<array>& inputs, array& out) override;
926 void eval_gpu(const std::vector<array>& inputs, array& out) override;
927
933
934 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
935 void eval(const std::vector<array>& inputs, array& out);
936};
937
// Exp: single-output primitive with no parameters beyond its stream.
// Presumably the element-wise exponential (inferred from the name only).
// NOTE(review): listing lines 945-949 are missing from this extract, so some
// member declarations are likely absent here.
938class Exp : public UnaryPrimitive {
939 public:
940 explicit Exp(Stream stream) : UnaryPrimitive(stream) {}
941
942 void eval_cpu(const std::vector<array>& inputs, array& out) override;
943 void eval_gpu(const std::vector<array>& inputs, array& out) override;
944
950
951 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
952 void eval(const std::vector<array>& inputs, array& out);
953};
954
// Expm1: single-output primitive with no parameters beyond its stream.
// Presumably element-wise exp(x) - 1 (inferred from the name only).
// NOTE(review): listing lines 962-965 are missing from this extract, so some
// member declarations are likely absent here.
955class Expm1 : public UnaryPrimitive {
956 public:
957 explicit Expm1(Stream stream) : UnaryPrimitive(stream) {}
958
959 void eval_cpu(const std::vector<array>& inputs, array& out) override;
960 void eval_gpu(const std::vector<array>& inputs, array& out) override;
961
966
967 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
968 void eval(const std::vector<array>& inputs, array& out);
969};
970
// FFT: single-output primitive storing a copy of the transform axes plus two
// flags, `inverse` and `real`. How the flags select the transform variant is
// not visible in this header — see the implementation.
// NOTE(review): listing lines 983-985 are missing from this extract.
971class FFT : public UnaryPrimitive {
972 public:
973 explicit FFT(
974 Stream stream,
975 const std::vector<size_t>& axes,
976 bool inverse,
977 bool real)
978 : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {}
979
980 void eval_cpu(const std::vector<array>& inputs, array& out) override;
981 void eval_gpu(const std::vector<array>& inputs, array& out) override;
982
986
// Declared here, defined elsewhere.
987 bool is_equivalent(const Primitive& other) const override;
988
989 private:
990 std::vector<size_t> axes_;
991 bool inverse_;
992 bool real_;
993
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
994 void eval(const std::vector<array>& inputs, array& out);
995};
996
// Floor: single-output primitive with no parameters beyond its stream.
// Presumably element-wise floor (inferred from the name only).
// NOTE(review): listing lines 1004-1008 are missing from this extract, so
// some member declarations are likely absent here.
997class Floor : public UnaryPrimitive {
998 public:
999 explicit Floor(Stream stream) : UnaryPrimitive(stream) {}
1000
1001 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1002 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1003
1009
1010 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
1011 void eval(const std::vector<array>& inputs, array& out);
1012};
1013
// Full: single-output primitive with no parameters beyond its stream.
// Presumably fills the output with a value taken from its input (inferred
// from the name only — confirm against the implementation).
// NOTE(review): listing lines 1021-1024 are missing from this extract, so
// some member declarations are likely absent here.
1014class Full : public UnaryPrimitive {
1015 public:
1016 explicit Full(Stream stream) : UnaryPrimitive(stream) {}
1017
1018 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1019 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1020
1025
1026 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
1027 void eval(const std::vector<array>& inputs, array& out);
1028};
1029
// Gather: single-output primitive storing copies of a list of axes and a
// list of slice sizes.
// NOTE(review): listing lines 1041-1043 are missing from this extract.
1030class Gather : public UnaryPrimitive {
1031 public:
1032 explicit Gather(
1033 Stream stream,
1034 const std::vector<int>& axes,
1035 const std::vector<int>& slice_sizes)
1036 : UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes) {}
1037
1038 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1039 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1040
// Declared here, defined elsewhere.
1044 bool is_equivalent(const Primitive& other) const override;
1045
1046 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
1047 void eval(const std::vector<array>& inputs, array& out);
1048 std::vector<int> axes_;
1049 std::vector<int> slice_sizes_;
1050};
1051
// Greater: single-output primitive with no parameters beyond its stream.
// Presumably element-wise `>` comparison (inferred from the name only).
// NOTE(review): listing lines 1059-1063 are missing from this extract, so
// some member declarations are likely absent here.
1052class Greater : public UnaryPrimitive {
1053 public:
1054 explicit Greater(Stream stream) : UnaryPrimitive(stream) {}
1055
1056 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1057 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1058
1064
1065 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
1066 void eval(const std::vector<array>& inputs, array& out);
1067};
1068
// GreaterEqual (name taken from the constructor below): single-output
// primitive with no parameters beyond its stream. Presumably element-wise
// `>=` comparison (inferred from the name only).
// NOTE(review): the `class ... {` header line (listing line 1069) and listing
// lines 1076-1080 are missing from this extract — confirm against upstream.
1070 public:
1071 explicit GreaterEqual(Stream stream) : UnaryPrimitive(stream) {}
1072
1073 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1074 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1075
1081
1082 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
1083 void eval(const std::vector<array>& inputs, array& out);
1084};
1085
// Hadamard: single-output primitive parameterized by a float scale.
// Presumably a scaled Hadamard transform (inferred from the name only).
// NOTE(review): listing lines 1094-1097 are missing from this extract.
1086class Hadamard : public UnaryPrimitive {
1087 public:
1088 explicit Hadamard(Stream stream, float scale)
1089 : UnaryPrimitive(stream), scale_(scale) {}
1090
1091 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1092 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1093
1098
// Declared here, defined elsewhere.
1099 bool is_equivalent(const Primitive& other) const override;
1100
1101 private:
1102 float scale_;
1103
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
1104 void eval(const std::vector<array>& inputs, array& out);
1105};
1106
// Less: single-output primitive with no parameters beyond its stream.
// Presumably element-wise `<` comparison (inferred from the name only).
// NOTE(review): listing lines 1114-1118 are missing from this extract, so
// some member declarations are likely absent here.
1107class Less : public UnaryPrimitive {
1108 public:
1109 explicit Less(Stream stream) : UnaryPrimitive(stream) {}
1110
1111 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1112 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1113
1119
1120 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
1121 void eval(const std::vector<array>& inputs, array& out);
1122};
1123
// LessEqual (name taken from the constructor below): single-output primitive
// with no parameters beyond its stream. Presumably element-wise `<=`
// comparison (inferred from the name only).
// NOTE(review): the `class ... {` header line (listing line 1124) and listing
// lines 1131-1135 are missing from this extract — confirm against upstream.
1125 public:
1126 explicit LessEqual(Stream stream) : UnaryPrimitive(stream) {}
1127
1128 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1129 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1130
1136
1137 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
1138 void eval(const std::vector<array>& inputs, array& out);
1139};
1140
// Load: single-output primitive holding a shared io::Reader, a byte/element
// offset (units not visible here) and an endianness-swap flag (default
// false). Presumably reads array data from the reader at that offset
// (inferred from the names — confirm against the implementation).
// NOTE(review): listing line 1160 is missing from this extract.
1141class Load : public UnaryPrimitive {
1142 public:
1143 explicit Load(
1144 Stream stream,
1145 std::shared_ptr<io::Reader> reader,
1146 size_t offset,
1147 bool swap_endianness = false)
1148 : UnaryPrimitive(stream),
1149 reader_(std::move(reader)),
1150 offset_(offset),
1151 swap_endianness_(swap_endianness) {
// For GPU streams, calling io_stream() here forces its function-local
// static CPU stream to be created at construction time; the returned
// reference is discarded. NOTE(review): the reason for creating it eagerly
// is not visible in this header.
1152 if (stream.device == Device::gpu) {
1153 io_stream();
1154 }
1155 }
1156
1157 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1158 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1159
1161
1162 private:
// Returns a CPU stream created once, on first call, and shared by every Load
// instance (function-local static).
// NOTE(review): the `;` after the closing brace is a redundant empty
// declaration and can trigger -Wextra-semi.
1163 Stream& io_stream() {
1164 static Stream io_stream = new_stream(Device::cpu);
1165 return io_stream;
1166 };
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
1167 void eval(const std::vector<array>& inputs, array& out);
1168 std::shared_ptr<io::Reader> reader_;
1169 size_t offset_;
1170 bool swap_endianness_;
1171};
1172
// Log: single-output primitive parameterized by a logarithm base chosen from
// the nested Base enum (two, ten or e).
// NOTE(review): listing lines 1183-1186 are missing from this extract, so
// some member declarations are likely absent here.
1173class Log : public UnaryPrimitive {
1174 public:
1175 enum Base { two, ten, e };
1176
1177 explicit Log(Stream stream, Base base)
1178 : UnaryPrimitive(stream), base_(base) {}
1179
1180 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1181 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1182
1187
// Prints "Log" for base e, "Log2" for base two and "Log10" for base ten.
// There is no default case, so a value outside the enum prints nothing.
1188 void print(std::ostream& os) override {
1189 switch (base_) {
1190 case e:
1191 os << "Log";
1192 break;
1193 case two:
1194 os << "Log2";
1195 break;
1196 case ten:
1197 os << "Log10";
1198 break;
1199 }
1200 }
1201
1202 private:
1203 Base base_;
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
1204 void eval(const std::vector<array>& inputs, array& out);
1205};
1206
// Log1p: single-output primitive with no parameters beyond its stream.
// Presumably element-wise log(1 + x) (inferred from the name only).
// NOTE(review): listing lines 1214-1217 are missing from this extract, so
// some member declarations are likely absent here.
1207class Log1p : public UnaryPrimitive {
1208 public:
1209 explicit Log1p(Stream stream) : UnaryPrimitive(stream) {}
1210
1211 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1212 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1213
1218
1219 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
1220 void eval(const std::vector<array>& inputs, array& out);
1221};
1222
// LogicalNot (name taken from the constructor below): single-output
// primitive with no parameters beyond its stream. Presumably element-wise
// logical negation (inferred from the name only).
// NOTE(review): the `class ... {` header line (listing line 1223) and listing
// lines 1230-1234 are missing from this extract — confirm against upstream.
1224 public:
1225 explicit LogicalNot(Stream stream) : UnaryPrimitive(stream) {}
1226
1227 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1228 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1229
1235
1236 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
1237 void eval(const std::vector<array>& inputs, array& out);
1238};
1239
// LogicalAnd (name taken from the constructor below): single-output
// primitive with no parameters beyond its stream. Presumably element-wise
// logical AND (inferred from the name only).
// NOTE(review): the `class ... {` header line (listing line 1240) and listing
// lines 1247-1251 are missing from this extract — confirm against upstream.
1241 public:
1242 explicit LogicalAnd(Stream stream) : UnaryPrimitive(stream) {}
1243
1244 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1245 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1246
1252
1253 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
1254 void eval(const std::vector<array>& inputs, array& out);
1255};
1256
// LogicalOr (name taken from the constructor below): single-output primitive
// with no parameters beyond its stream. Presumably element-wise logical OR
// (inferred from the name only).
// NOTE(review): the `class ... {` header line (listing line 1257) and listing
// lines 1264-1268 are missing from this extract — confirm against upstream.
1258 public:
1259 explicit LogicalOr(Stream stream) : UnaryPrimitive(stream) {}
1260
1261 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1262 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1263
1269
1270 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
1271 void eval(const std::vector<array>& inputs, array& out);
1272};
1273
// LogAddExp (name taken from the constructor below): single-output primitive
// with no parameters beyond its stream. Presumably element-wise
// log(exp(a) + exp(b)) (inferred from the name only).
// NOTE(review): the `class ... {` header line (listing line 1274) and listing
// lines 1281-1285 are missing from this extract — confirm against upstream.
1275 public:
1276 explicit LogAddExp(Stream stream) : UnaryPrimitive(stream) {}
1277
1278 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1279 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1280
1286
1287 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
1288 void eval(const std::vector<array>& inputs, array& out);
1289};
1290
// Matmul: single-output primitive with no parameters beyond its stream and
// an explicit vjp override. Unlike most siblings, no private eval() helper
// is declared in the visible part of the class.
// NOTE(review): listing lines 1304-1306 are missing from this extract, so
// some member declarations are likely absent here.
1291class Matmul : public UnaryPrimitive {
1292 public:
1293 explicit Matmul(Stream stream) : UnaryPrimitive(stream) {}
1294
1295 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1296 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1297
1298 std::vector<array> vjp(
1299 const std::vector<array>& primals,
1300 const std::vector<array>& cotangents,
1301 const std::vector<int>& argnums,
1302 const std::vector<array>& outputs) override;
1303
1307};
1308
// Maximum: single-output primitive with no parameters beyond its stream.
// Presumably the element-wise maximum (inferred from the name only).
// NOTE(review): listing lines 1316-1320 are missing from this extract, so
// some member declarations are likely absent here.
1309class Maximum : public UnaryPrimitive {
1310 public:
1311 explicit Maximum(Stream stream) : UnaryPrimitive(stream) {}
1312
1313 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1314 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1315
1321
1322 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
1323 void eval(const std::vector<array>& inputs, array& out);
1324};
1325
// Minimum: single-output primitive with no parameters beyond its stream.
// Presumably the element-wise minimum (inferred from the name only).
// NOTE(review): listing lines 1333-1337 are missing from this extract, so
// some member declarations are likely absent here.
1326class Minimum : public UnaryPrimitive {
1327 public:
1328 explicit Minimum(Stream stream) : UnaryPrimitive(stream) {}
1329
1330 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1331 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1332
1338
1339 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
1340 void eval(const std::vector<array>& inputs, array& out);
1341};
1342
// Multiply: single-output primitive with no parameters beyond its stream.
// Presumably element-wise multiplication (inferred from the name only).
// NOTE(review): listing lines 1350-1354 are missing from this extract, so
// some member declarations are likely absent here.
1343class Multiply : public UnaryPrimitive {
1344 public:
1345 explicit Multiply(Stream stream) : UnaryPrimitive(stream) {}
1346
1347 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1348 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1349
1355
1356 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
1357 void eval(const std::vector<array>& inputs, array& out);
1358};
1359
// Negative: single-output primitive with no parameters beyond its stream.
// Presumably element-wise negation (inferred from the name only).
// NOTE(review): listing lines 1367-1371 are missing from this extract, so
// some member declarations are likely absent here.
1360class Negative : public UnaryPrimitive {
1361 public:
1362 explicit Negative(Stream stream) : UnaryPrimitive(stream) {}
1363
1364 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1365 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1366
1372
1373 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
1374 void eval(const std::vector<array>& inputs, array& out);
1375};
1376
// NotEqual: single-output primitive with no parameters beyond its stream.
// Presumably element-wise `!=` comparison (inferred from the name only).
// NOTE(review): listing lines 1384-1388 are missing from this extract, so
// some member declarations are likely absent here.
1377class NotEqual : public UnaryPrimitive {
1378 public:
1379 explicit NotEqual(Stream stream) : UnaryPrimitive(stream) {}
1380
1381 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1382 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1383
1389
1390 private:
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
1391 void eval(const std::vector<array>& inputs, array& out);
1392};
1393
// Single-output primitive storing a list of axes, an `inverted` flag and a
// Dtype. Its shape inference always reports one output with an empty shape.
// NOTE(review): the `class ... {` header line (listing line 1394) and the
// constructor's opening line (listing line 1396) are missing from this
// extract, so the class name is not visible here (presumably
// NumberOfElements — confirm upstream). Listing lines 1409-1410 are missing.
1395 public:
// `axes` is taken by value and moved into axes_.
1397 Stream stream,
1398 std::vector<int> axes,
1399 bool inverted,
1400 Dtype dtype)
1401 : UnaryPrimitive(stream),
1402 axes_(std::move(axes)),
1403 inverted_(inverted),
1404 dtype_(dtype) {}
1405
1406 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1407 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1408
// Declared here, defined elsewhere.
1411 bool is_equivalent(const Primitive& other) const override;
// One output whose shape is empty, regardless of `inputs` (which is unused).
1412 std::vector<std::vector<int>> output_shapes(
1413 const std::vector<array>& inputs) override {
1414 return {{}};
1415 }
1416
1417 private:
1418 std::vector<int> axes_;
1419 bool inverted_;
1420 Dtype dtype_;
1421
// Defined elsewhere; presumably the shared CPU/GPU implementation — TODO confirm.
1422 void eval(const std::vector<array>& inputs, array& out);
1423};
1424
// Pad primitive: derives from UnaryPrimitive and stores three integer lists
// (axes, low pad sizes, high pad sizes) copied from the constructor arguments.
// NOTE(review): presumably low_pad_size[i]/high_pad_size[i] apply to axes[i];
// this header does not show how they are consumed. Confirm in the .cpp.
1425class Pad : public UnaryPrimitive {
1426 public:
  // Copies the three vectors into the corresponding members.
1427 explicit Pad(
1428 Stream stream,
1429 const std::vector<int>& axes,
1430 const std::vector<int>& low_pad_size,
1431 const std::vector<int>& high_pad_size)
1432 : UnaryPrimitive(stream),
1433 axes_(axes),
1434 low_pad_size_(low_pad_size),
1435 high_pad_size_(high_pad_size) {}
1436
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
1437 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1438 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1439
  // NOTE(review): lines 1440-1442 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
1443 bool is_equivalent(const Primitive& other) const override;
1444
1445 private:
1446 std::vector<int> axes_;
1447 std::vector<int> low_pad_size_;
1448 std::vector<int> high_pad_size_;
1449
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1450 void eval(const std::vector<array>& inputs, array& out);
1451};
1452
// Partition primitive. NOTE(review): the class header line (1453) is missing
// from this listing; the constructor shows the name and the UnaryPrimitive base.
// Stores an integer `kth` and an integer `axis`; their interpretation is not
// visible in this header.
1454 public:
  // Stores kth and axis verbatim; the stream goes to the base class.
1455 explicit Partition(Stream stream, int kth, int axis)
1456 : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
1457
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
1458 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1459 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1460
  // NOTE(review): lines 1461-1464 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
1465 bool is_equivalent(const Primitive& other) const override;
1466
1467 private:
1468 int kth_;
1469 int axis_;
1470
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1471 void eval(const std::vector<array>& inputs, array& out);
1472};
1473
// Power primitive: derives from UnaryPrimitive and declares CPU/GPU
// evaluation entry points that take a list of inputs and write one `out` array.
// NOTE(review): presumably element-wise exponentiation; confirm in the
// implementation, which is not part of this header.
1474class Power : public UnaryPrimitive {
1475 public:
  // Forwards the stream to the UnaryPrimitive base; no other state is stored.
1476 explicit Power(Stream stream) : UnaryPrimitive(stream) {}
1477
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
1478 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1479 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1480
  // NOTE(review): lines 1481-1485 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
1486
1487 private:
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1488 void eval(const std::vector<array>& inputs, array& out);
1489};
1490
// NOTE(review): the class header (line 1491) and the constructor name line
// (1493) are missing from this listing. The initializer list shows the base is
// UnaryPrimitive; the class is presumably QuantizedMatmul. Confirm against the
// original header.
// Stores a group size, a bit width and a transpose flag.
1492 public:
1494 Stream stream,
1495 int group_size,
1496 int bits,
1497 bool transpose)
1498 : UnaryPrimitive(stream),
1499 group_size_(group_size),
1500 bits_(bits),
1501 transpose_(transpose) {}
1502
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
1503 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1504 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1505
  // NOTE(review): lines 1506-1508 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
1509 bool is_equivalent(const Primitive& other) const override;
1510
1511 private:
1512 int group_size_;
1513 int bits_;
1514 bool transpose_;
1515
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1516 void eval(const std::vector<array>& inputs, array& out);
1517};
1518
// GatherQMM primitive. NOTE(review): the class header line (1519) is missing
// from this listing; the constructor shows the name and the UnaryPrimitive base.
// Stores a group size, a bit width and a transpose flag.
// NOTE(review): presumably a gather + quantized matrix multiply; confirm in the
// implementation.
1520 public:
  // Stores the three parameters verbatim; the stream goes to the base class.
1521 explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
1522 : UnaryPrimitive(stream),
1523 group_size_(group_size),
1524 bits_(bits),
1525 transpose_(transpose) {}
1526
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
1527 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1528 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1529
  // NOTE(review): lines 1530-1532 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
1533 bool is_equivalent(const Primitive& other) const override;
1534
1535 private:
1536 int group_size_;
1537 int bits_;
1538 bool transpose_;
1539
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1540 void eval(const std::vector<array>& inputs, array& out);
1541};
1542
// RandomBits primitive. NOTE(review): the class header line (1543) is missing
// from this listing; the constructor shows the name and the UnaryPrimitive base.
// Stores a shape (copied from the argument) and an integer `width`.
// NOTE(review): the unit of `width` (bits vs. bytes) is not visible here.
// Confirm in the implementation.
1544 public:
1545 explicit RandomBits(Stream stream, const std::vector<int>& shape, int width)
1546 : UnaryPrimitive(stream), shape_(shape), width_(width) {}
1547
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
1548 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1549 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1550
  // NOTE(review): lines 1551-1552 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
1553 bool is_equivalent(const Primitive& other) const override;
1554
1555 private:
1556 std::vector<int> shape_;
1557 int width_;
1558
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1559 void eval(const std::vector<array>& inputs, array& out);
1560};
1561
// Reshape primitive: derives from UnaryPrimitive and stores the shape passed
// to the constructor (copied into shape_).
1562class Reshape : public UnaryPrimitive {
1563 public:
1564 explicit Reshape(Stream stream, const std::vector<int>& shape)
1565 : UnaryPrimitive(stream), shape_(shape) {}
1566
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
1567 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1568 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1569
  // NOTE(review): lines 1570-1572 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
1573 bool is_equivalent(const Primitive& other) const override;
1574
1575 private:
1576 std::vector<int> shape_;
1577
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1578 void eval(const std::vector<array>& inputs, array& out);
1579
  // Returns a (bool, strides) pair computed from the input and output arrays.
  // NOTE(review): presumably the bool says whether a copy is needed and the
  // vector holds the output strides. Confirm in the implementation.
1580 std::pair<bool, std::vector<size_t>> prepare_reshape(
1581 const array& in,
1582 const array& out);
  // Takes the input, a set of output strides and the output array.
  // NOTE(review): by its name, presumably makes `out` alias `in`'s buffer
  // instead of copying. Confirm in the implementation.
1583 void shared_buffer_reshape(
1584 const array& in,
1585 const std::vector<size_t>& out_strides,
1586 array& out);
1587};
1588
// Reduce primitive: derives from UnaryPrimitive and is parameterized by a
// reduction kind (ReduceType) and a list of axes, both stored as members.
1589class Reduce : public UnaryPrimitive {
1590 public:
  // The reduction kinds this primitive can represent.
1591 enum ReduceType { And, Or, Sum, Prod, Min, Max };
1592
  // Stores the reduction kind and a copy of `axes`.
1593 explicit Reduce(
1594 Stream stream,
1595 ReduceType reduce_type,
1596 const std::vector<int>& axes)
1597 : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1598
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
1599 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1600 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1601
  // NOTE(review): line 1602 is missing from this listing (presumably a
  // DEFINE_* macro invocation). Confirm against the original header.
1603
  // Vector-Jacobian product override; only vjp is declared explicitly here.
1604 std::vector<array> vjp(
1605 const std::vector<array>& primals,
1606 const std::vector<array>& cotangents,
1607 const std::vector<int>& argnums,
1608 const std::vector<array>& outputs) override;
1609
  // Output shape computation is declared here and defined out of line.
1610 std::vector<std::vector<int>> output_shapes(
1611 const std::vector<array>& inputs) override;
1612
  // Writes the name of the stored reduction kind ("And", "Or", "Sum",
  // "Prod", "Min" or "Max") to `os`. Writes nothing for any other value.
1613 void print(std::ostream& os) override {
1614 switch (reduce_type_) {
1615 case And:
1616 os << "And";
1617 break;
1618 case Or:
1619 os << "Or";
1620 break;
1621 case Sum:
1622 os << "Sum";
1623 break;
1624 case Prod:
1625 os << "Prod";
1626 break;
1627 case Min:
1628 os << "Min";
1629 break;
1630 case Max:
1631 os << "Max";
1632 break;
1633 }
1634 }
1635 bool is_equivalent(const Primitive& other) const override;
1636
1637 private:
1638 ReduceType reduce_type_;
1639 std::vector<int> axes_;
1640
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1641 void eval(const std::vector<array>& inputs, array& out);
1642};
1643
// Round primitive: derives from UnaryPrimitive and declares CPU/GPU
// evaluation entry points that take a list of inputs and write one `out` array.
// NOTE(review): presumably element-wise rounding; the rounding mode is not
// visible in this header. Confirm in the implementation.
1644class Round : public UnaryPrimitive {
1645 public:
  // Forwards the stream to the UnaryPrimitive base; no other state is stored.
1646 explicit Round(Stream stream) : UnaryPrimitive(stream) {}
1647
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
1648 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1649 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1650
  // NOTE(review): lines 1651-1655 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
1656
1657 private:
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1658 void eval(const std::vector<array>& inputs, array& out);
1659};
1660
// Scan primitive: derives from UnaryPrimitive and stores a reduction kind, an
// axis, and two flags (`reverse`, `inclusive`).
// NOTE(review): presumably a cumulative reduction along `axis` (print() emits
// "Cum" + kind); direction/inclusivity semantics are not visible here.
1661class Scan : public UnaryPrimitive {
1662 public:
  // NOTE(review): line 1663 is missing from this listing. It presumably
  // declares the nested ReduceType enum used below (Sum, Prod, Min, Max appear
  // in print()). Confirm against the original header.
1664
  // Stores all four parameters verbatim; the stream goes to the base class.
1665 explicit Scan(
1666 Stream stream,
1667 ReduceType reduce_type,
1668 int axis,
1669 bool reverse,
1670 bool inclusive)
1671 : UnaryPrimitive(stream),
1672 reduce_type_(reduce_type),
1673 axis_(axis),
1674 reverse_(reverse),
1675 inclusive_(inclusive) {}
1676
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
1677 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1678 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1679
  // NOTE(review): lines 1680-1681 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
1682
  // Writes "Cum" followed by the reduction kind name, e.g. "CumSum".
1683 void print(std::ostream& os) override {
1684 os << "Cum";
1685 switch (reduce_type_) {
1686 case Sum:
1687 os << "Sum";
1688 break;
1689 case Prod:
1690 os << "Prod";
1691 break;
1692 case Min:
1693 os << "Min";
1694 break;
1695 case Max:
1696 os << "Max";
1697 break;
1698 }
1699 }
1700 bool is_equivalent(const Primitive& other) const override;
1701
1702 private:
1703 ReduceType reduce_type_;
1704 int axis_;
1705 bool reverse_;
1706 bool inclusive_;
1707
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1708 void eval(const std::vector<array>& inputs, array& out);
1709};
1710
// Scatter primitive: derives from UnaryPrimitive and stores a reduction kind
// plus a list of axes (copied from the constructor argument).
// NOTE(review): presumably the reduction kind says how scattered updates are
// combined with existing values. Confirm in the implementation.
1711class Scatter : public UnaryPrimitive {
1712 public:
  // NOTE(review): line 1713 is missing from this listing. It presumably
  // declares the nested ReduceType enum used below (Sum, Prod, Min, Max, None
  // appear in print()). Confirm against the original header.
1714
1715 explicit Scatter(
1716 Stream stream,
1717 ReduceType reduce_type,
1718 const std::vector<int>& axes)
1719 : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1720
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
1721 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1722 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1723
  // NOTE(review): lines 1724-1725 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
1726
  // Writes "Scatter", followed by a space and the reduction kind name unless
  // the kind is None (then just "Scatter").
1727 void print(std::ostream& os) override {
1728 os << "Scatter";
1729 switch (reduce_type_) {
1730 case Sum:
1731 os << " Sum";
1732 break;
1733 case Prod:
1734 os << " Prod";
1735 break;
1736 case Min:
1737 os << " Min";
1738 break;
1739 case Max:
1740 os << " Max";
1741 break;
1742 case None:
1743 break;
1744 }
1745 }
1746 bool is_equivalent(const Primitive& other) const override;
1747
1748 private:
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1749 void eval(const std::vector<array>& inputs, array& out);
1750 ReduceType reduce_type_;
1751 std::vector<int> axes_;
1752};
1753
// Sigmoid primitive: derives from UnaryPrimitive and declares CPU/GPU
// evaluation entry points that take a list of inputs and write one `out` array.
// NOTE(review): presumably the element-wise logistic sigmoid; confirm in the
// implementation, which is not part of this header.
1754class Sigmoid : public UnaryPrimitive {
1755 public:
  // Forwards the stream to the UnaryPrimitive base; no other state is stored.
1756 explicit Sigmoid(Stream stream) : UnaryPrimitive(stream) {}
1757
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
1758 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1759 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1760
  // NOTE(review): lines 1761-1765 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
1766
1767 private:
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1768 void eval(const std::vector<array>& inputs, array& out);
1769};
1770
// Sign primitive: derives from UnaryPrimitive and declares CPU/GPU
// evaluation entry points that take a list of inputs and write one `out` array.
// NOTE(review): presumably the element-wise sign function; confirm in the
// implementation, which is not part of this header.
1771class Sign : public UnaryPrimitive {
1772 public:
  // Forwards the stream to the UnaryPrimitive base; no other state is stored.
1773 explicit Sign(Stream stream) : UnaryPrimitive(stream) {}
1774
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
1775 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1776 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1777
  // NOTE(review): lines 1778-1782 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
1783
1784 private:
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1785 void eval(const std::vector<array>& inputs, array& out);
1786};
1787
// Sin primitive: derives from UnaryPrimitive and declares CPU/GPU
// evaluation entry points that take a list of inputs and write one `out` array.
// NOTE(review): presumably the element-wise sine; confirm in the
// implementation, which is not part of this header.
1788class Sin : public UnaryPrimitive {
1789 public:
  // Forwards the stream to the UnaryPrimitive base; no other state is stored.
1790 explicit Sin(Stream stream) : UnaryPrimitive(stream) {}
1791
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
1792 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1793 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1794
  // NOTE(review): lines 1795-1799 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
1800
1801 private:
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1802 void eval(const std::vector<array>& inputs, array& out);
1803};
1804
// Sinh primitive: derives from UnaryPrimitive and declares CPU/GPU
// evaluation entry points that take a list of inputs and write one `out` array.
// NOTE(review): presumably the element-wise hyperbolic sine; confirm in the
// implementation, which is not part of this header.
1805class Sinh : public UnaryPrimitive {
1806 public:
  // Forwards the stream to the UnaryPrimitive base; no other state is stored.
1807 explicit Sinh(Stream stream) : UnaryPrimitive(stream) {}
1808
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
1809 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1810 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1811
  // NOTE(review): lines 1812-1816 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
1817
1818 private:
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1819 void eval(const std::vector<array>& inputs, array& out);
1820};
1821
// Slice primitive: derives from UnaryPrimitive and stores three integer lists
// (start indices, end indices, strides) copied from the constructor arguments.
// NOTE(review): presumably one entry per input dimension with an exclusive
// end index; not visible in this header. Confirm in the implementation.
1822class Slice : public UnaryPrimitive {
1823 public:
  // Copies the three vectors into the corresponding members.
1824 explicit Slice(
1825 Stream stream,
1826 const std::vector<int>& start_indices,
1827 const std::vector<int>& end_indices,
1828 const std::vector<int>& strides)
1829 : UnaryPrimitive(stream),
1830 start_indices_(start_indices),
1831 end_indices_(end_indices),
1832 strides_(strides) {}
1833
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
1834 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1835 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1836
  // NOTE(review): lines 1837-1839 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
1840 bool is_equivalent(const Primitive& other) const override;
1841
1842 private:
1843 std::vector<int> start_indices_;
1844 std::vector<int> end_indices_;
1845 std::vector<int> strides_;
1846
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1847 void eval(const std::vector<array>& inputs, array& out);
1848};
1849
// SliceUpdate primitive. NOTE(review): the class header line (1850) is missing
// from this listing; the constructor shows the name and the UnaryPrimitive base.
// Stores the same three lists as Slice: start indices, end indices, strides.
// NOTE(review): presumably writes an update into the described slice of the
// source; confirm in the implementation.
1851 public:
  // Copies the three vectors into the corresponding members.
1852 explicit SliceUpdate(
1853 Stream stream,
1854 const std::vector<int>& start_indices,
1855 const std::vector<int>& end_indices,
1856 const std::vector<int>& strides)
1857 : UnaryPrimitive(stream),
1858 start_indices_(start_indices),
1859 end_indices_(end_indices),
1860 strides_(strides) {}
1861
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
1862 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1863 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1864
  // NOTE(review): lines 1865-1867 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
1868 bool is_equivalent(const Primitive& other) const override;
1869
1870 private:
1871 std::vector<int> start_indices_;
1872 std::vector<int> end_indices_;
1873 std::vector<int> strides_;
1874
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1875 void eval(const std::vector<array>& inputs, array& out);
1876
  // Returns an (int64_t, vector<int64_t>) tuple computed from `in`.
  // NOTE(review): presumably a data offset and the strides of the sliced view.
  // Confirm in the implementation.
1877 std::tuple<int64_t, std::vector<int64_t>> prepare_slice(const array& in);
1878};
1879
// Softmax primitive: derives from UnaryPrimitive and stores a `precise` flag.
// NOTE(review): presumably `precise` selects a higher-precision computation
// path; its effect is not visible in this header. Confirm in the implementation.
1880class Softmax : public UnaryPrimitive {
1881 public:
  // Stores the flag verbatim; the stream goes to the base class.
1882 explicit Softmax(Stream stream, bool precise)
1883 : UnaryPrimitive(stream), precise_(precise) {}
1884
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
1885 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1886 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1887
  // NOTE(review): lines 1888-1891 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
1892
1893 bool is_equivalent(const Primitive& other) const override;
1894
1895 private:
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1896 void eval(const std::vector<array>& inputs, array& out);
1897 bool precise_;
1898};
1899
// Sort primitive: derives from UnaryPrimitive and stores the integer `axis`
// passed to the constructor.
// NOTE(review): presumably sorts along `axis`; ordering and stability are not
// visible in this header. Confirm in the implementation.
1900class Sort : public UnaryPrimitive {
1901 public:
  // Stores the axis verbatim; the stream goes to the base class.
1902 explicit Sort(Stream stream, int axis)
1903 : UnaryPrimitive(stream), axis_(axis) {}
1904
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
1905 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1906 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1907
  // NOTE(review): lines 1908-1911 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
1912 bool is_equivalent(const Primitive& other) const override;
1913
1914 private:
1915 int axis_;
1916
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1917 void eval(const std::vector<array>& inputs, array& out);
1918};
1919
// Split primitive: derives directly from Primitive (not UnaryPrimitive), so
// its evaluation entry points write to a vector of output arrays.
// Stores a list of indices (copied) and an integer axis.
// NOTE(review): presumably the indices are split points along `axis`; confirm
// in the implementation.
1920class Split : public Primitive {
1921 public:
1922 explicit Split(Stream stream, const std::vector<int>& indices, int axis)
1923 : Primitive(stream), indices_(indices), axis_(axis) {}
1924
  // Multi-output evaluation overrides (declared here, defined elsewhere).
1925 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
1926 override;
1927 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
1928 override;
1929
  // NOTE(review): lines 1930-1932 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
1933 bool is_equivalent(const Primitive& other) const override;
1934
1935 private:
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1936 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
1937
1938 std::vector<int> indices_;
1939 int axis_;
1940};
1941
// Square primitive: derives from UnaryPrimitive and declares CPU/GPU
// evaluation entry points that take a list of inputs and write one `out` array.
// NOTE(review): presumably the element-wise square; confirm in the
// implementation, which is not part of this header.
1942class Square : public UnaryPrimitive {
1943 public:
  // Forwards the stream to the UnaryPrimitive base; no other state is stored.
1944 explicit Square(Stream stream) : UnaryPrimitive(stream) {}
1945
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
1946 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1947 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1948
  // NOTE(review): lines 1949-1953 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
1954
1955 private:
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1956 void eval(const std::vector<array>& inputs, array& out);
1957};
1958
// Sqrt primitive: derives from UnaryPrimitive and stores a `recip` flag that
// defaults to false. The flag changes the printed name (see print()).
// NOTE(review): presumably recip == true means reciprocal square root; the
// computation itself is not visible in this header.
1959class Sqrt : public UnaryPrimitive {
1960 public:
  // `recip` is optional and defaults to false.
1961 explicit Sqrt(Stream stream, bool recip = false)
1962 : UnaryPrimitive(stream), recip_(recip) {}
1963
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
1964 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1965 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1966
  // NOTE(review): lines 1967-1969 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
1970 bool is_equivalent(const Primitive& other) const override;
1971
  // Writes "Rsqrt" when the recip flag is set, otherwise "Sqrt".
1972 void print(std::ostream& os) override {
1973 if (recip_) {
1974 os << "Rsqrt";
1975 } else {
1976 os << "Sqrt";
1977 }
1978 }
1979
1980 private:
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1981 void eval(const std::vector<array>& inputs, array& out);
1982 bool recip_;
1983};
1984
// StopGradient primitive. NOTE(review): the class header line (1985) is
// missing from this listing; the constructor shows the name and the
// UnaryPrimitive base.
// NOTE(review): presumably passes its input through while blocking gradient
// propagation; confirm in the implementation.
1986 public:
  // Forwards the stream to the UnaryPrimitive base; no other state is stored.
1987 explicit StopGradient(Stream stream) : UnaryPrimitive(stream) {}
1988
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
1989 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1990 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1991
  // NOTE(review): lines 1992-1995 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
1996
1997 private:
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
1998 void eval(const std::vector<array>& inputs, array& out);
1999};
2000
// Subtract primitive: derives from UnaryPrimitive and declares CPU/GPU
// evaluation entry points that take a list of inputs and write one `out` array.
// NOTE(review): presumably element-wise subtraction; confirm in the
// implementation, which is not part of this header.
2001class Subtract : public UnaryPrimitive {
2002 public:
  // Forwards the stream to the UnaryPrimitive base; no other state is stored.
2003 explicit Subtract(Stream stream) : UnaryPrimitive(stream) {}
2004
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
2005 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2006 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2007
  // NOTE(review): lines 2008-2012 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
2013
2014 private:
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
2015 void eval(const std::vector<array>& inputs, array& out);
2016};
2017
// Tan primitive: derives from UnaryPrimitive and declares CPU/GPU
// evaluation entry points that take a list of inputs and write one `out` array.
// NOTE(review): presumably the element-wise tangent; confirm in the
// implementation, which is not part of this header.
2018class Tan : public UnaryPrimitive {
2019 public:
  // Forwards the stream to the UnaryPrimitive base; no other state is stored.
2020 explicit Tan(Stream stream) : UnaryPrimitive(stream) {}
2021
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
2022 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2023 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2024
  // NOTE(review): lines 2025-2029 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
2030
2031 private:
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
2032 void eval(const std::vector<array>& inputs, array& out);
2033};
2034
// Tanh primitive: derives from UnaryPrimitive and declares CPU/GPU
// evaluation entry points that take a list of inputs and write one `out` array.
// NOTE(review): presumably the element-wise hyperbolic tangent; confirm in the
// implementation, which is not part of this header.
2035class Tanh : public UnaryPrimitive {
2036 public:
  // Forwards the stream to the UnaryPrimitive base; no other state is stored.
2037 explicit Tanh(Stream stream) : UnaryPrimitive(stream) {}
2038
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
2039 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2040 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2041
  // NOTE(review): lines 2042-2046 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
2047
2048 private:
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
2049 void eval(const std::vector<array>& inputs, array& out);
2050};
2051
// Uniform primitive: derives from UnaryPrimitive and declares CPU/GPU
// evaluation entry points that take a list of inputs and write one `out` array.
// NOTE(review): presumably uniform random sampling; what the inputs represent
// (bounds, keys, ...) is not visible in this header. Confirm in the .cpp.
2052class Uniform : public UnaryPrimitive {
2053 public:
  // Forwards the stream to the UnaryPrimitive base; no other state is stored.
2054 explicit Uniform(Stream stream) : UnaryPrimitive(stream) {}
2055
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
2056 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2057 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2058
  // NOTE(review): lines 2059-2061 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
2062
2063 private:
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
2064 void eval(const std::vector<array>& inputs, array& out);
2065};
2066
// View primitive: derives from UnaryPrimitive and stores the Dtype passed to
// the constructor. Unlike most classes in this listing it declares no private
// eval() helper, and print() is declared here but defined out of line.
// NOTE(review): presumably reinterprets the input's data as `dtype`; confirm
// in the implementation.
2067class View : public UnaryPrimitive {
2068 public:
  // Stores the dtype verbatim; the stream goes to the base class.
2069 explicit View(Stream stream, Dtype dtype)
2070 : UnaryPrimitive(stream), dtype_(dtype) {}
2071
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
2072 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2073 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2074
  // NOTE(review): line 2075 is missing from this listing (presumably a
  // DEFINE_* macro invocation). Confirm against the original header.
2076 void print(std::ostream& os) override;
2077 bool is_equivalent(const Primitive& other) const override;
2078
2079 private:
2080 Dtype dtype_;
2081};
2082
// Transpose primitive. NOTE(review): the class header line (2083) is missing
// from this listing; the constructor shows the name and the UnaryPrimitive base.
// Stores a list of axes copied from the constructor argument.
// NOTE(review): presumably `axes` is the permutation of input dimensions;
// confirm in the implementation.
2084 public:
2085 explicit Transpose(Stream stream, const std::vector<int>& axes)
2086 : UnaryPrimitive(stream), axes_(axes) {}
2087
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
2088 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2089 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2090
  // NOTE(review): lines 2091-2093 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
2094 bool is_equivalent(const Primitive& other) const override;
2095
2096 private:
2097 std::vector<int> axes_;
2098
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
2099 void eval(const std::vector<array>& inputs, array& out);
2100};
2101
2102/* QR Factorization primitive. */
// Derives directly from Primitive, so evaluation writes to a vector of output
// arrays rather than a single one.
// NOTE(review): presumably the outputs are the Q and R factors; their order is
// not visible in this header. Confirm in the implementation.
2103class QRF : public Primitive {
2104 public:
  // Forwards the stream to the Primitive base; no other state is stored.
2105 explicit QRF(Stream stream) : Primitive(stream) {}
2106
  // Multi-output evaluation overrides (declared here, defined elsewhere).
2107 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2108 override;
2109 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2110 override;
2111
  // NOTE(review): line 2112 is missing from this listing (presumably a
  // DEFINE_* macro invocation). Confirm against the original header.
2113
2114 private:
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
2115 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
2116};
2117
2118/* SVD primitive. */
// Derives directly from Primitive, so evaluation writes to a vector of output
// arrays rather than a single one.
// NOTE(review): presumably the outputs are the singular value decomposition
// factors; their number and order are not visible here. Confirm in the .cpp.
2119class SVD : public Primitive {
2120 public:
  // Forwards the stream to the Primitive base; no other state is stored.
2121 explicit SVD(Stream stream) : Primitive(stream) {}
2122
  // Multi-output evaluation overrides (declared here, defined elsewhere).
2123 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2124 override;
2125 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2126 override;
2127
  // NOTE(review): lines 2128-2129 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
2130
2131 private:
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
2132 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
2133};
2134
2135/* Matrix inversion primitive. */
// Derives from UnaryPrimitive and stores two flags, `tri` and `upper`.
// NOTE(review): presumably `tri` marks the input as triangular and `upper`
// selects upper vs. lower; their effect is not visible in this header.
2136class Inverse : public UnaryPrimitive {
2137 public:
  // Stores both flags verbatim; the stream goes to the base class.
2138 explicit Inverse(Stream stream, bool tri, bool upper)
2139 : UnaryPrimitive(stream), tri_(tri), upper_(upper) {}
2140
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
2141 void eval_cpu(const std::vector<array>& inputs, array& output) override;
2142 void eval_gpu(const std::vector<array>& inputs, array& output) override;
2143
  // NOTE(review): lines 2144-2145 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
2146
2147 private:
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
2148 void eval(const std::vector<array>& inputs, array& output);
2149 bool tri_;
2150 bool upper_;
2151};
2152
// Cholesky primitive: derives from UnaryPrimitive and stores an `upper` flag.
// NOTE(review): presumably `upper` selects the upper- vs. lower-triangular
// factor; its effect is not visible in this header. Confirm in the .cpp.
2153class Cholesky : public UnaryPrimitive {
2154 public:
  // Stores the flag verbatim; the stream goes to the base class.
2155 explicit Cholesky(Stream stream, bool upper)
2156 : UnaryPrimitive(stream), upper_(upper) {}
2157
  // Backend-specific evaluation overrides (declared here, defined elsewhere).
2158 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2159 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2160
  // NOTE(review): lines 2161-2162 are missing from this listing (presumably
  // DEFINE_* macro invocations). Confirm against the original header.
2163
2164 private:
  // Private helper; presumably shared by eval_cpu/eval_gpu. Verify in the .cpp.
2165 void eval(const std::vector<array>& inputs, array& output);
2166 bool upper_;
2167};
2168
2169} // namespace mlx::core
Definition primitives.h:155
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Abs(Stream stream)
Definition primitives.h:157
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:164
std::vector< std::vector< int > > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:166
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:165
Definition primitives.h:172
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Add(Stream stream)
Definition primitives.h:174
Definition primitives.h:189
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
AddMM(Stream stream, float alpha, float beta)
Definition primitives.h:191
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition primitives.h:213
Arange(Stream stream, double start, double stop, double step)
Definition primitives.h:215
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:232
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCos(Stream stream)
Definition primitives.h:234
Definition primitives.h:249
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCosh(Stream stream)
Definition primitives.h:251
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:266
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcSin(Stream stream)
Definition primitives.h:268
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:283
ArcSinh(Stream stream)
Definition primitives.h:285
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:317
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTan2(Stream stream)
Definition primitives.h:319
Definition primitives.h:300
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcTan(Stream stream)
Definition primitives.h:302
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:334
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTanh(Stream stream)
Definition primitives.h:336
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:351
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArgPartition(Stream stream, int kth, int axis)
Definition primitives.h:353
Definition primitives.h:371
ReduceType
Definition primitives.h:373
@ ArgMin
Definition primitives.h:374
@ ArgMax
Definition primitives.h:375
ArgReduce(Stream stream, ReduceType reduce_type, int axis)
Definition primitives.h:378
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:397
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArgSort(Stream stream, int axis)
Definition primitives.h:399
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:436
AsStrided(Stream stream, std::vector< int > shape, std::vector< size_t > strides, size_t offset)
Definition primitives.h:438
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:416
void eval_gpu(const std::vector< array > &inputs, array &out) override
AsType(Stream stream, Dtype dtype)
Definition primitives.h:418
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:463
BitwiseBinary(Stream stream, Op op)
Definition primitives.h:467
void eval_cpu(const std::vector< array > &inputs, array &out) override
Op
Definition primitives.h:465
@ And
Definition primitives.h:465
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:483
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
BlockMaskedMM(Stream stream, int block_size)
Definition primitives.h:485
Definition primitives.h:526
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Broadcast(Stream stream, const std::vector< int > &shape)
Definition primitives.h:528
Definition primitives.h:545
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Ceil(Stream stream)
Definition primitives.h:547
Definition primitives.h:2153
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cholesky(Stream stream, bool upper)
Definition primitives.h:2155
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:562
Compiled(Stream stream, std::vector< array > inputs, std::vector< array > outputs, std::vector< array > tape, std::unordered_set< uintptr_t > constant_ids)
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the outputs.
Definition primitives.h:605
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Concatenate(Stream stream, int axis)
Definition primitives.h:607
Definition primitives.h:624
Conjugate(Stream stream)
Definition primitives.h:626
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:640
void eval_gpu(const std::vector< array > &inputs, array &out) override
Convolution(Stream stream, const std::vector< int > &kernel_strides, const std::vector< int > &padding, const std::vector< int > &kernel_dilation, const std::vector< int > &input_dilation, const int groups=1, const bool flip=false)
Definition primitives.h:642
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition primitives.h:681
void eval_gpu(const std::vector< array > &inputs, array &out) override
Copy(Stream stream)
Definition primitives.h:683
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:698
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cos(Stream stream)
Definition primitives.h:700
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:715
void eval_gpu(const std::vector< array > &inputs, array &out) override
Cosh(Stream stream)
Definition primitives.h:717
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:732
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
CustomTransforms(Stream stream, int num_outputs, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> vjp, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< int > &)> jvp, std::function< std::pair< std::vector< array >, std::vector< int > >(const std::vector< array > &, const std::vector< int > &)> vmap)
Definition primitives.h:734
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the outputs.
Definition primitives.h:784
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the outputs.
Depends(Stream stream)
Definition primitives.h:786
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:822
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
DivMod(Stream stream)
Definition primitives.h:824
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the outputs.
Definition primitives.h:805
Divide(Stream stream)
Definition primitives.h:807
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:878
Equal(Stream stream, bool equal_nan=false)
Definition primitives.h:880
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:904
Erf(Stream stream)
Definition primitives.h:906
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:921
void eval_gpu(const std::vector< array > &inputs, array &out) override
ErfInv(Stream stream)
Definition primitives.h:923
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:938
Exp(Stream stream)
Definition primitives.h:940
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:955
Expm1(Stream stream)
Definition primitives.h:957
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:971
FFT(Stream stream, const std::vector< size_t > &axes, bool inverse, bool real)
Definition primitives.h:973
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:997
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Floor(Stream stream)
Definition primitives.h:999
Definition primitives.h:1014
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Full(Stream stream)
Definition primitives.h:1016
Definition primitives.h:1030
Gather(Stream stream, const std::vector< int > &axes, const std::vector< int > &slice_sizes)
Definition primitives.h:1032
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:506
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
GatherMM(Stream stream)
Definition primitives.h:508
Definition primitives.h:1519
GatherQMM(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1521
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1069
void eval_cpu(const std::vector< array > &inputs, array &out) override
GreaterEqual(Stream stream)
Definition primitives.h:1071
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1052
Greater(Stream stream)
Definition primitives.h:1054
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1086
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Hadamard(Stream stream, float scale)
Definition primitives.h:1088
Definition primitives.h:2136
void eval_gpu(const std::vector< array > &inputs, array &output) override
Inverse(Stream stream, bool tri, bool upper)
Definition primitives.h:2138
void eval_cpu(const std::vector< array > &inputs, array &output) override
Definition primitives.h:1124
LessEqual(Stream stream)
Definition primitives.h:1126
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1107
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Less(Stream stream)
Definition primitives.h:1109
Definition primitives.h:1141
void eval_gpu(const std::vector< array > &inputs, array &out) override
Load(Stream stream, std::shared_ptr< io::Reader > reader, size_t offset, bool swap_endianness=false)
Definition primitives.h:1143
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1207
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Log1p(Stream stream)
Definition primitives.h:1209
Definition primitives.h:1274
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogAddExp(Stream stream)
Definition primitives.h:1276
Definition primitives.h:1173
Base
Definition primitives.h:1175
Log(Stream stream, Base base)
Definition primitives.h:1177
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1240
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalAnd(Stream stream)
Definition primitives.h:1242
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1223
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalNot(Stream stream)
Definition primitives.h:1225
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1257
void eval_cpu(const std::vector< array > &inputs, array &out) override
LogicalOr(Stream stream)
Definition primitives.h:1259
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1291
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
Matmul(Stream stream)
Definition primitives.h:1293
Definition primitives.h:1309
Maximum(Stream stream)
Definition primitives.h:1311
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1326
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Minimum(Stream stream)
Definition primitives.h:1328
Definition primitives.h:1343
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Multiply(Stream stream)
Definition primitives.h:1345
Definition primitives.h:1360
void eval_gpu(const std::vector< array > &inputs, array &out) override
Negative(Stream stream)
Definition primitives.h:1362
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1377
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
NotEqual(Stream stream)
Definition primitives.h:1379
Definition primitives.h:1394
void eval_gpu(const std::vector< array > &inputs, array &out) override
NumberOfElements(Stream stream, std::vector< int > axes, bool inverted, Dtype dtype)
Definition primitives.h:1396
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1425
void eval_cpu(const std::vector< array > &inputs, array &out) override
Pad(Stream stream, const std::vector< int > &axes, const std::vector< int > &low_pad_size, const std::vector< int > &high_pad_size)
Definition primitives.h:1427
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1453
void eval_cpu(const std::vector< array > &inputs, array &out) override
Partition(Stream stream, int kth, int axis)
Definition primitives.h:1455
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1474
void eval_cpu(const std::vector< array > &inputs, array &out) override
Power(Stream stream)
Definition primitives.h:1476
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:48
virtual void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
The vector-Jacobian product.
virtual ~Primitive()=default
Primitive(const Primitive &other)=delete
Primitive(Primitive &&other)=delete
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
Primitive & operator=(Primitive &&other)=delete
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
Primitive & operator=(const Primitive &other)=delete
virtual std::vector< std::vector< int > > output_shapes(const std::vector< array > &inputs)
Get the output shapes of the primitive.
const Device & device()
The device the primitive will run on.
Definition primitives.h:53
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
The Jacobian-vector product.
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes)
The primitive must know how to vectorize itself across the given axes.
virtual void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
virtual void print(std::ostream &os)=0
Print the primitive.
Primitive(Stream stream)
Definition primitives.h:50
Definition primitives.h:2103
QRF(Stream stream)
Definition primitives.h:2105
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:1491
void eval_gpu(const std::vector< array > &inputs, array &out) override
QuantizedMatmul(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1493
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1543
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
RandomBits(Stream stream, const std::vector< int > &shape, int width)
Definition primitives.h:1545
Definition primitives.h:1589
Reduce(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1593
ReduceType
Definition primitives.h:1591
@ And
Definition primitives.h:1591
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:861
Remainder(Stream stream)
Definition primitives.h:863
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1562
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Reshape(Stream stream, const std::vector< int > &shape)
Definition primitives.h:1564
Definition primitives.h:1644
Round(Stream stream)
Definition primitives.h:1646
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2119
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
SVD(Stream stream)
Definition primitives.h:2121
Definition primitives.h:1661
void eval_cpu(const std::vector< array > &inputs, array &out) override
ReduceType
Definition primitives.h:1663
@ Max
Definition primitives.h:1663
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Scan(Stream stream, ReduceType reduce_type, int axis, bool reverse, bool inclusive)
Definition primitives.h:1665
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1711
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
ReduceType
Definition primitives.h:1713
@ Max
Definition primitives.h:1713
void eval_cpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1727
void eval_gpu(const std::vector< array > &inputs, array &out) override
Scatter(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1715
Definition primitives.h:844
void eval_gpu(const std::vector< array > &inputs, array &out) override
Select(Stream stream)
Definition primitives.h:846
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1754
Sigmoid(Stream stream)
Definition primitives.h:1756
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1771
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Sign(Stream stream)
Definition primitives.h:1773
Definition primitives.h:1788
Sin(Stream stream)
Definition primitives.h:1790
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1805
Sinh(Stream stream)
Definition primitives.h:1807
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1822
void eval_cpu(const std::vector< array > &inputs, array &out) override
Slice(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
Definition primitives.h:1824
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1850
SliceUpdate(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
Definition primitives.h:1852
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1880
void eval_gpu(const std::vector< array > &inputs, array &out) override
Softmax(Stream stream, bool precise)
Definition primitives.h:1882
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1900
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sort(Stream stream, int axis)
Definition primitives.h:1902
Definition primitives.h:1920
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Split(Stream stream, const std::vector< int > &indices, int axis)
Definition primitives.h:1922
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition primitives.h:1959
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sqrt(Stream stream, bool recip=false)
Definition primitives.h:1961
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1942
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Square(Stream stream)
Definition primitives.h:1944
Definition primitives.h:1985
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
StopGradient(Stream stream)
Definition primitives.h:1987
Definition primitives.h:2001
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Subtract(Stream stream)
Definition primitives.h:2003
Definition primitives.h:2018
Tan(Stream stream)
Definition primitives.h:2020
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2035
void eval_gpu(const std::vector< array > &inputs, array &out) override
Tanh(Stream stream)
Definition primitives.h:2037
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2083
Transpose(Stream stream, const std::vector< int > &axes)
Definition primitives.h:2085
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:127
UnaryPrimitive & operator=(const UnaryPrimitive &other)=delete
UnaryPrimitive(Stream stream)
An abstract base class for a primitive with a single output.
Definition primitives.h:132
virtual void eval_gpu(const std::vector< array > &inputs, array &output)=0
UnaryPrimitive(UnaryPrimitive &&other)=delete
virtual void eval_cpu(const std::vector< array > &inputs, array &output)=0
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:142
UnaryPrimitive(const UnaryPrimitive &other)=delete
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition primitives.h:137
UnaryPrimitive & operator=(UnaryPrimitive &&other)=delete
virtual ~UnaryPrimitive()=default
Definition primitives.h:2052
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Uniform(Stream stream)
Definition primitives.h:2054
Definition primitives.h:2067
void eval_cpu(const std::vector< array > &inputs, array &out) override
View(Stream stream, Dtype dtype)
Definition primitives.h:2069
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition array.h:20
Op op
Definition binary.h:141
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
array tri(int n, int m, int k, Dtype type, StreamOrDevice s={})
array transpose(const array &a, std::vector< int > axes, StreamOrDevice s={})
Permutes the dimensions according to the given axes.
Definition allocator.h:7
std::pair< std::vector< array >, std::vector< array > > jvp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &tangents)
Computes the output and Jacobian-vector product (JVP) of a function.
std::pair< std::vector< array >, std::vector< array > > vjp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &cotangents)
Computes the output and vector-Jacobian product (VJP) of a function.
void eval(std::vector< array > outputs)
std::function< array(const array &)> vmap(const std::function< array(const array &)> &fun, int in_axis=0, int out_axis=0)
Automatically vectorize a unary function over the requested axes.
#define DEFINE_DEFAULT_IS_EQUIVALENT()
Definition primitives.h:34
#define DEFINE_PRINT(PRIMITIVE)
Definition primitives.h:29
#define DEFINE_INPUT_OUTPUT_SHAPE()
Definition primitives.h:39
#define DEFINE_GRADS()
Definition primitives.h:17
#define DEFINE_VMAP()
Definition primitives.h:12
Definition ops.h:37
Definition binary_ops.h:270
Definition ops.h:185
Definition ops.h:163
Definition ops.h:29
Definition ops.h:78
Definition ops.h:141
Definition binary_ops.h:277
Definition ops.h:119
Definition device.h:7
Definition dtype.h:13
Definition stream.h:9
Device device
Definition stream.h:11