mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:11:43 +08:00
docs update
This commit is contained in:

committed by
CircleCI Docs

parent
85f70be0e6
commit
0ec311dff3
@@ -118,104 +118,115 @@ $(function() { codefold.init(0); });
|
||||
<div class="line"><a id="l00030" name="l00030"></a><span class="lineno"> 30</span>}</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00031" name="l00031"></a><span class="lineno"> 31</span> </div>
|
||||
<div class="line"><a id="l00032" name="l00032"></a><span class="lineno"> 32</span><span class="comment">// Collapse dims that are contiguous to possibly route to a better kernel</span></div>
|
||||
<div class="line"><a id="l00033" name="l00033"></a><span class="lineno"> 33</span><span class="comment">// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})</span></div>
|
||||
<div class="line"><a id="l00034" name="l00034"></a><span class="lineno"> 34</span><span class="comment">// should return {{2, 4}, {{1, 2}}}.</span></div>
|
||||
<div class="line"><a id="l00035" name="l00035"></a><span class="lineno"> 35</span><span class="comment">//</span></div>
|
||||
<div class="line"><a id="l00036" name="l00036"></a><span class="lineno"> 36</span><span class="comment">// When multiple arrays are passed they should all have the same shape. The</span></div>
|
||||
<div class="line"><a id="l00037" name="l00037"></a><span class="lineno"> 37</span><span class="comment">// collapsed axes are also the same so one shape is returned.</span></div>
|
||||
<div class="line"><a id="l00038" name="l00038"></a><span class="lineno"> 38</span><span class="keyword">template</span> <<span class="keyword">typename</span> str<span class="keywordtype">id</span>e_t></div>
|
||||
<div class="line"><a id="l00039" name="l00039"></a><span class="lineno"> 39</span><span class="keyword">inline</span> std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>></div>
|
||||
<div class="foldopen" id="foldopen00040" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00040" name="l00040"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a9d151ba3e138be1954d2f51f85806b0c"> 40</a></span><a class="code hl_function" href="namespacemlx_1_1core.html#a9d151ba3e138be1954d2f51f85806b0c">collapse_contiguous_dims</a>(</div>
|
||||
<div class="line"><a id="l00041" name="l00041"></a><span class="lineno"> 41</span> <span class="keyword">const</span> std::vector<int>& shape,</div>
|
||||
<div class="line"><a id="l00042" name="l00042"></a><span class="lineno"> 42</span> <span class="keyword">const</span> std::vector<std::vector<stride_t>> strides) {</div>
|
||||
<div class="line"><a id="l00043" name="l00043"></a><span class="lineno"> 43</span> <span class="comment">// Make a vector that has axes separated with -1. Collapse all axes between</span></div>
|
||||
<div class="line"><a id="l00044" name="l00044"></a><span class="lineno"> 44</span> <span class="comment">// -1.</span></div>
|
||||
<div class="line"><a id="l00045" name="l00045"></a><span class="lineno"> 45</span> std::vector<int> to_collapse;</div>
|
||||
<div class="line"><a id="l00046" name="l00046"></a><span class="lineno"> 46</span> <span class="keywordflow">if</span> (shape.size() > 0) {</div>
|
||||
<div class="line"><a id="l00047" name="l00047"></a><span class="lineno"> 47</span> to_collapse.push_back(0);</div>
|
||||
<div class="line"><a id="l00048" name="l00048"></a><span class="lineno"> 48</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 1; i < shape.size(); i++) {</div>
|
||||
<div class="line"><a id="l00049" name="l00049"></a><span class="lineno"> 49</span> <span class="keywordtype">bool</span> contiguous = <span class="keyword">true</span>;</div>
|
||||
<div class="line"><a id="l00050" name="l00050"></a><span class="lineno"> 50</span> <span class="keywordflow">for</span> (<span class="keyword">const</span> std::vector<stride_t>& st : strides) {</div>
|
||||
<div class="line"><a id="l00051" name="l00051"></a><span class="lineno"> 51</span> <span class="keywordflow">if</span> (st[i] * shape[i] != st[i - 1]) {</div>
|
||||
<div class="line"><a id="l00052" name="l00052"></a><span class="lineno"> 52</span> contiguous = <span class="keyword">false</span>;</div>
|
||||
<div class="line"><a id="l00053" name="l00053"></a><span class="lineno"> 53</span> }</div>
|
||||
<div class="line"><a id="l00054" name="l00054"></a><span class="lineno"> 54</span> <span class="keywordflow">if</span> (!contiguous) {</div>
|
||||
<div class="line"><a id="l00055" name="l00055"></a><span class="lineno"> 55</span> <span class="keywordflow">break</span>;</div>
|
||||
<div class="line"><a id="l00056" name="l00056"></a><span class="lineno"> 56</span> }</div>
|
||||
<div class="line"><a id="l00057" name="l00057"></a><span class="lineno"> 57</span> }</div>
|
||||
<div class="line"><a id="l00058" name="l00058"></a><span class="lineno"> 58</span> <span class="keywordflow">if</span> (!contiguous) {</div>
|
||||
<div class="line"><a id="l00059" name="l00059"></a><span class="lineno"> 59</span> to_collapse.push_back(-1);</div>
|
||||
<div class="line"><a id="l00060" name="l00060"></a><span class="lineno"> 60</span> }</div>
|
||||
<div class="line"><a id="l00061" name="l00061"></a><span class="lineno"> 61</span> to_collapse.push_back(i);</div>
|
||||
<div class="line"><a id="l00062" name="l00062"></a><span class="lineno"> 62</span> }</div>
|
||||
<div class="line"><a id="l00063" name="l00063"></a><span class="lineno"> 63</span> to_collapse.push_back(-1);</div>
|
||||
<div class="line"><a id="l00064" name="l00064"></a><span class="lineno"> 64</span> }</div>
|
||||
<div class="line"><a id="l00065" name="l00065"></a><span class="lineno"> 65</span> </div>
|
||||
<div class="line"><a id="l00066" name="l00066"></a><span class="lineno"> 66</span> std::vector<int> out_shape;</div>
|
||||
<div class="line"><a id="l00067" name="l00067"></a><span class="lineno"> 67</span> std::vector<std::vector<stride_t>> out_strides(strides.size());</div>
|
||||
<div class="line"><a id="l00068" name="l00068"></a><span class="lineno"> 68</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < to_collapse.size(); i++) {</div>
|
||||
<div class="line"><a id="l00069" name="l00069"></a><span class="lineno"> 69</span> <span class="keywordtype">int</span> current_shape = shape[to_collapse[i]];</div>
|
||||
<div class="line"><a id="l00070" name="l00070"></a><span class="lineno"> 70</span> <span class="keywordflow">while</span> (to_collapse[++i] != -1) {</div>
|
||||
<div class="line"><a id="l00071" name="l00071"></a><span class="lineno"> 71</span> current_shape *= shape[to_collapse[i]];</div>
|
||||
<div class="line"><a id="l00072" name="l00072"></a><span class="lineno"> 72</span> }</div>
|
||||
<div class="line"><a id="l00073" name="l00073"></a><span class="lineno"> 73</span> out_shape.push_back(current_shape);</div>
|
||||
<div class="line"><a id="l00074" name="l00074"></a><span class="lineno"> 74</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> j = 0; j < strides.size(); j++) {</div>
|
||||
<div class="line"><a id="l00075" name="l00075"></a><span class="lineno"> 75</span> <span class="keyword">const</span> std::vector<stride_t>& st = strides[j];</div>
|
||||
<div class="line"><a id="l00076" name="l00076"></a><span class="lineno"> 76</span> out_strides[j].push_back(st[to_collapse[i - 1]]);</div>
|
||||
<div class="line"><a id="l00077" name="l00077"></a><span class="lineno"> 77</span> }</div>
|
||||
<div class="line"><a id="l00078" name="l00078"></a><span class="lineno"> 78</span> }</div>
|
||||
<div class="line"><a id="l00079" name="l00079"></a><span class="lineno"> 79</span> </div>
|
||||
<div class="line"><a id="l00080" name="l00080"></a><span class="lineno"> 80</span> <span class="keywordflow">return</span> std::make_tuple(out_shape, out_strides);</div>
|
||||
<div class="line"><a id="l00081" name="l00081"></a><span class="lineno"> 81</span>}</div>
|
||||
<div class="line"><a id="l00032" name="l00032"></a><span class="lineno"> 32</span><span class="keyword">template</span> <<span class="keyword">typename</span> str<span class="keywordtype">id</span>e_t></div>
|
||||
<div class="foldopen" id="foldopen00033" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00033" name="l00033"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#ac9fb1286a1a00395e901dbff80560895"> 33</a></span>std::vector<stride_t> <a class="code hl_function" href="namespacemlx_1_1core.html#ac9fb1286a1a00395e901dbff80560895">make_contiguous_strides</a>(<span class="keyword">const</span> std::vector<int>& shape) {</div>
|
||||
<div class="line"><a id="l00034" name="l00034"></a><span class="lineno"> 34</span> std::vector<stride_t> strides(shape.size(), 1);</div>
|
||||
<div class="line"><a id="l00035" name="l00035"></a><span class="lineno"> 35</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = shape.size() - 1; i > 0; i--) {</div>
|
||||
<div class="line"><a id="l00036" name="l00036"></a><span class="lineno"> 36</span> strides[i - 1] = strides[i] * shape[i];</div>
|
||||
<div class="line"><a id="l00037" name="l00037"></a><span class="lineno"> 37</span> }</div>
|
||||
<div class="line"><a id="l00038" name="l00038"></a><span class="lineno"> 38</span> <span class="keywordflow">return</span> strides;</div>
|
||||
<div class="line"><a id="l00039" name="l00039"></a><span class="lineno"> 39</span>}</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00082" name="l00082"></a><span class="lineno"> 82</span> </div>
|
||||
<div class="line"><a id="l00083" name="l00083"></a><span class="lineno"> 83</span><span class="keyword">inline</span> std::tuple<std::vector<int>, std::vector<std::vector<size_t>>></div>
|
||||
<div class="foldopen" id="foldopen00084" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00084" name="l00084"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a8430e0baac3f6d8a2ab22428f9c0b7e2"> 84</a></span><a class="code hl_function" href="namespacemlx_1_1core.html#a9d151ba3e138be1954d2f51f85806b0c">collapse_contiguous_dims</a>(<span class="keyword">const</span> std::vector<array>& xs) {</div>
|
||||
<div class="line"><a id="l00085" name="l00085"></a><span class="lineno"> 85</span> std::vector<std::vector<size_t>> strides;</div>
|
||||
<div class="line"><a id="l00086" name="l00086"></a><span class="lineno"> 86</span> <span class="keywordflow">for</span> (<span class="keyword">auto</span>& x : xs) {</div>
|
||||
<div class="line"><a id="l00087" name="l00087"></a><span class="lineno"> 87</span> strides.emplace_back(x.strides());</div>
|
||||
<div class="line"><a id="l00088" name="l00088"></a><span class="lineno"> 88</span> }</div>
|
||||
<div class="line"><a id="l00089" name="l00089"></a><span class="lineno"> 89</span> <span class="keywordflow">return</span> <a class="code hl_function" href="namespacemlx_1_1core.html#a9d151ba3e138be1954d2f51f85806b0c">collapse_contiguous_dims</a>(xs[0].shape(), strides);</div>
|
||||
<div class="line"><a id="l00040" name="l00040"></a><span class="lineno"> 40</span> </div>
|
||||
<div class="line"><a id="l00041" name="l00041"></a><span class="lineno"> 41</span><span class="comment">// Collapse dims that are contiguous to possibly route to a better kernel</span></div>
|
||||
<div class="line"><a id="l00042" name="l00042"></a><span class="lineno"> 42</span><span class="comment">// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})</span></div>
|
||||
<div class="line"><a id="l00043" name="l00043"></a><span class="lineno"> 43</span><span class="comment">// should return {{2, 4}, {{1, 2}}}.</span></div>
|
||||
<div class="line"><a id="l00044" name="l00044"></a><span class="lineno"> 44</span><span class="comment">//</span></div>
|
||||
<div class="line"><a id="l00045" name="l00045"></a><span class="lineno"> 45</span><span class="comment">// When multiple arrays are passed they should all have the same shape. The</span></div>
|
||||
<div class="line"><a id="l00046" name="l00046"></a><span class="lineno"> 46</span><span class="comment">// collapsed axes are also the same so one shape is returned.</span></div>
|
||||
<div class="line"><a id="l00047" name="l00047"></a><span class="lineno"> 47</span><span class="keyword">template</span> <<span class="keyword">typename</span> str<span class="keywordtype">id</span>e_t></div>
|
||||
<div class="line"><a id="l00048" name="l00048"></a><span class="lineno"> 48</span><span class="keyword">inline</span> std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>></div>
|
||||
<div class="foldopen" id="foldopen00049" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00049" name="l00049"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a9d151ba3e138be1954d2f51f85806b0c"> 49</a></span><a class="code hl_function" href="namespacemlx_1_1core.html#a9d151ba3e138be1954d2f51f85806b0c">collapse_contiguous_dims</a>(</div>
|
||||
<div class="line"><a id="l00050" name="l00050"></a><span class="lineno"> 50</span> <span class="keyword">const</span> std::vector<int>& shape,</div>
|
||||
<div class="line"><a id="l00051" name="l00051"></a><span class="lineno"> 51</span> <span class="keyword">const</span> std::vector<std::vector<stride_t>> strides) {</div>
|
||||
<div class="line"><a id="l00052" name="l00052"></a><span class="lineno"> 52</span> <span class="comment">// Make a vector that has axes separated with -1. Collapse all axes between</span></div>
|
||||
<div class="line"><a id="l00053" name="l00053"></a><span class="lineno"> 53</span> <span class="comment">// -1.</span></div>
|
||||
<div class="line"><a id="l00054" name="l00054"></a><span class="lineno"> 54</span> std::vector<int> to_collapse;</div>
|
||||
<div class="line"><a id="l00055" name="l00055"></a><span class="lineno"> 55</span> <span class="keywordflow">if</span> (shape.size() > 0) {</div>
|
||||
<div class="line"><a id="l00056" name="l00056"></a><span class="lineno"> 56</span> to_collapse.push_back(0);</div>
|
||||
<div class="line"><a id="l00057" name="l00057"></a><span class="lineno"> 57</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 1; i < shape.size(); i++) {</div>
|
||||
<div class="line"><a id="l00058" name="l00058"></a><span class="lineno"> 58</span> <span class="keywordtype">bool</span> contiguous = <span class="keyword">true</span>;</div>
|
||||
<div class="line"><a id="l00059" name="l00059"></a><span class="lineno"> 59</span> <span class="keywordflow">for</span> (<span class="keyword">const</span> std::vector<stride_t>& st : strides) {</div>
|
||||
<div class="line"><a id="l00060" name="l00060"></a><span class="lineno"> 60</span> <span class="keywordflow">if</span> (st[i] * shape[i] != st[i - 1]) {</div>
|
||||
<div class="line"><a id="l00061" name="l00061"></a><span class="lineno"> 61</span> contiguous = <span class="keyword">false</span>;</div>
|
||||
<div class="line"><a id="l00062" name="l00062"></a><span class="lineno"> 62</span> }</div>
|
||||
<div class="line"><a id="l00063" name="l00063"></a><span class="lineno"> 63</span> <span class="keywordflow">if</span> (!contiguous) {</div>
|
||||
<div class="line"><a id="l00064" name="l00064"></a><span class="lineno"> 64</span> <span class="keywordflow">break</span>;</div>
|
||||
<div class="line"><a id="l00065" name="l00065"></a><span class="lineno"> 65</span> }</div>
|
||||
<div class="line"><a id="l00066" name="l00066"></a><span class="lineno"> 66</span> }</div>
|
||||
<div class="line"><a id="l00067" name="l00067"></a><span class="lineno"> 67</span> <span class="keywordflow">if</span> (!contiguous) {</div>
|
||||
<div class="line"><a id="l00068" name="l00068"></a><span class="lineno"> 68</span> to_collapse.push_back(-1);</div>
|
||||
<div class="line"><a id="l00069" name="l00069"></a><span class="lineno"> 69</span> }</div>
|
||||
<div class="line"><a id="l00070" name="l00070"></a><span class="lineno"> 70</span> to_collapse.push_back(i);</div>
|
||||
<div class="line"><a id="l00071" name="l00071"></a><span class="lineno"> 71</span> }</div>
|
||||
<div class="line"><a id="l00072" name="l00072"></a><span class="lineno"> 72</span> to_collapse.push_back(-1);</div>
|
||||
<div class="line"><a id="l00073" name="l00073"></a><span class="lineno"> 73</span> }</div>
|
||||
<div class="line"><a id="l00074" name="l00074"></a><span class="lineno"> 74</span> </div>
|
||||
<div class="line"><a id="l00075" name="l00075"></a><span class="lineno"> 75</span> std::vector<int> out_shape;</div>
|
||||
<div class="line"><a id="l00076" name="l00076"></a><span class="lineno"> 76</span> std::vector<std::vector<stride_t>> out_strides(strides.size());</div>
|
||||
<div class="line"><a id="l00077" name="l00077"></a><span class="lineno"> 77</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < to_collapse.size(); i++) {</div>
|
||||
<div class="line"><a id="l00078" name="l00078"></a><span class="lineno"> 78</span> <span class="keywordtype">int</span> current_shape = shape[to_collapse[i]];</div>
|
||||
<div class="line"><a id="l00079" name="l00079"></a><span class="lineno"> 79</span> <span class="keywordflow">while</span> (to_collapse[++i] != -1) {</div>
|
||||
<div class="line"><a id="l00080" name="l00080"></a><span class="lineno"> 80</span> current_shape *= shape[to_collapse[i]];</div>
|
||||
<div class="line"><a id="l00081" name="l00081"></a><span class="lineno"> 81</span> }</div>
|
||||
<div class="line"><a id="l00082" name="l00082"></a><span class="lineno"> 82</span> out_shape.push_back(current_shape);</div>
|
||||
<div class="line"><a id="l00083" name="l00083"></a><span class="lineno"> 83</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> j = 0; j < strides.size(); j++) {</div>
|
||||
<div class="line"><a id="l00084" name="l00084"></a><span class="lineno"> 84</span> <span class="keyword">const</span> std::vector<stride_t>& st = strides[j];</div>
|
||||
<div class="line"><a id="l00085" name="l00085"></a><span class="lineno"> 85</span> out_strides[j].push_back(st[to_collapse[i - 1]]);</div>
|
||||
<div class="line"><a id="l00086" name="l00086"></a><span class="lineno"> 86</span> }</div>
|
||||
<div class="line"><a id="l00087" name="l00087"></a><span class="lineno"> 87</span> }</div>
|
||||
<div class="line"><a id="l00088" name="l00088"></a><span class="lineno"> 88</span> </div>
|
||||
<div class="line"><a id="l00089" name="l00089"></a><span class="lineno"> 89</span> <span class="keywordflow">return</span> std::make_tuple(out_shape, out_strides);</div>
|
||||
<div class="line"><a id="l00090" name="l00090"></a><span class="lineno"> 90</span>}</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00091" name="l00091"></a><span class="lineno"> 91</span> </div>
|
||||
<div class="line"><a id="l00092" name="l00092"></a><span class="lineno"> 92</span><span class="keyword">template</span> <<span class="keyword">typename</span>... Arrays, <span class="keyword">typename</span> = <a class="code hl_typedef" href="namespacemlx_1_1core.html#af89751d79339f3e4d9318ea97d64d114">enable_for_arrays_t</a><Arrays...>></div>
|
||||
<div class="line"><a id="l00092" name="l00092"></a><span class="lineno"> 92</span><span class="keyword">inline</span> std::tuple<std::vector<int>, std::vector<std::vector<size_t>>></div>
|
||||
<div class="foldopen" id="foldopen00093" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00093" name="l00093"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#ac813412cce77fc1340dcfefc6e099276"> 93</a></span><span class="keyword">inline</span> <span class="keyword">auto</span> <a class="code hl_function" href="namespacemlx_1_1core.html#a9d151ba3e138be1954d2f51f85806b0c">collapse_contiguous_dims</a>(Arrays&&... xs) {</div>
|
||||
<div class="line"><a id="l00094" name="l00094"></a><span class="lineno"> 94</span> <span class="keywordflow">return</span> <a class="code hl_function" href="namespacemlx_1_1core.html#a9d151ba3e138be1954d2f51f85806b0c">collapse_contiguous_dims</a>(</div>
|
||||
<div class="line"><a id="l00095" name="l00095"></a><span class="lineno"> 95</span> std::vector<array>{std::forward<Arrays>(xs)...});</div>
|
||||
<div class="line"><a id="l00096" name="l00096"></a><span class="lineno"> 96</span>}</div>
|
||||
<div class="line"><a id="l00093" name="l00093"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a8430e0baac3f6d8a2ab22428f9c0b7e2"> 93</a></span><a class="code hl_function" href="namespacemlx_1_1core.html#a9d151ba3e138be1954d2f51f85806b0c">collapse_contiguous_dims</a>(<span class="keyword">const</span> std::vector<array>& xs) {</div>
|
||||
<div class="line"><a id="l00094" name="l00094"></a><span class="lineno"> 94</span> std::vector<std::vector<size_t>> strides;</div>
|
||||
<div class="line"><a id="l00095" name="l00095"></a><span class="lineno"> 95</span> <span class="keywordflow">for</span> (<span class="keyword">auto</span>& x : xs) {</div>
|
||||
<div class="line"><a id="l00096" name="l00096"></a><span class="lineno"> 96</span> strides.emplace_back(x.strides());</div>
|
||||
<div class="line"><a id="l00097" name="l00097"></a><span class="lineno"> 97</span> }</div>
|
||||
<div class="line"><a id="l00098" name="l00098"></a><span class="lineno"> 98</span> <span class="keywordflow">return</span> <a class="code hl_function" href="namespacemlx_1_1core.html#a9d151ba3e138be1954d2f51f85806b0c">collapse_contiguous_dims</a>(xs[0].shape(), strides);</div>
|
||||
<div class="line"><a id="l00099" name="l00099"></a><span class="lineno"> 99</span>}</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00097" name="l00097"></a><span class="lineno"> 97</span> </div>
|
||||
<div class="line"><a id="l00098" name="l00098"></a><span class="lineno"> 98</span><span class="keyword">template</span> <<span class="keyword">typename</span> str<span class="keywordtype">id</span>e_t></div>
|
||||
<div class="foldopen" id="foldopen00099" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00099" name="l00099"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a847b0a276663d9ddb5cac905ee977f03"> 99</a></span><span class="keyword">inline</span> <span class="keyword">auto</span> <a class="code hl_function" href="namespacemlx_1_1core.html#a847b0a276663d9ddb5cac905ee977f03">check_contiguity</a>(</div>
|
||||
<div class="line"><a id="l00100" name="l00100"></a><span class="lineno"> 100</span> <span class="keyword">const</span> std::vector<int>& shape,</div>
|
||||
<div class="line"><a id="l00101" name="l00101"></a><span class="lineno"> 101</span> <span class="keyword">const</span> std::vector<stride_t>& strides) {</div>
|
||||
<div class="line"><a id="l00102" name="l00102"></a><span class="lineno"> 102</span> <span class="keywordtype">size_t</span> data_size = 1;</div>
|
||||
<div class="line"><a id="l00103" name="l00103"></a><span class="lineno"> 103</span> <span class="keywordtype">size_t</span> f_stride = 1;</div>
|
||||
<div class="line"><a id="l00104" name="l00104"></a><span class="lineno"> 104</span> <span class="keywordtype">size_t</span> b_stride = 1;</div>
|
||||
<div class="line"><a id="l00105" name="l00105"></a><span class="lineno"> 105</span> <span class="keywordtype">bool</span> is_row_contiguous = <span class="keyword">true</span>;</div>
|
||||
<div class="line"><a id="l00106" name="l00106"></a><span class="lineno"> 106</span> <span class="keywordtype">bool</span> is_col_contiguous = <span class="keyword">true</span>;</div>
|
||||
<div class="line"><a id="l00107" name="l00107"></a><span class="lineno"> 107</span> </div>
|
||||
<div class="line"><a id="l00108" name="l00108"></a><span class="lineno"> 108</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {</div>
|
||||
<div class="line"><a id="l00109" name="l00109"></a><span class="lineno"> 109</span> is_row_contiguous &= strides[i] == f_stride || shape[i] == 1;</div>
|
||||
<div class="line"><a id="l00110" name="l00110"></a><span class="lineno"> 110</span> is_col_contiguous &= strides[ri] == b_stride || shape[ri] == 1;</div>
|
||||
<div class="line"><a id="l00111" name="l00111"></a><span class="lineno"> 111</span> f_stride *= shape[i];</div>
|
||||
<div class="line"><a id="l00112" name="l00112"></a><span class="lineno"> 112</span> b_stride *= shape[ri];</div>
|
||||
<div class="line"><a id="l00113" name="l00113"></a><span class="lineno"> 113</span> <span class="keywordflow">if</span> (strides[i] > 0) {</div>
|
||||
<div class="line"><a id="l00114" name="l00114"></a><span class="lineno"> 114</span> data_size *= shape[i];</div>
|
||||
<div class="line"><a id="l00115" name="l00115"></a><span class="lineno"> 115</span> }</div>
|
||||
<div class="line"><a id="l00116" name="l00116"></a><span class="lineno"> 116</span> }</div>
|
||||
<div class="line"><a id="l00117" name="l00117"></a><span class="lineno"> 117</span> </div>
|
||||
<div class="line"><a id="l00118" name="l00118"></a><span class="lineno"> 118</span> <span class="keywordflow">return</span> std::make_tuple(data_size, is_row_contiguous, is_col_contiguous);</div>
|
||||
<div class="line"><a id="l00119" name="l00119"></a><span class="lineno"> 119</span>}</div>
|
||||
<div class="line"><a id="l00100" name="l00100"></a><span class="lineno"> 100</span> </div>
|
||||
<div class="line"><a id="l00101" name="l00101"></a><span class="lineno"> 101</span><span class="keyword">template</span> <<span class="keyword">typename</span>... Arrays, <span class="keyword">typename</span> = <a class="code hl_typedef" href="namespacemlx_1_1core.html#af89751d79339f3e4d9318ea97d64d114">enable_for_arrays_t</a><Arrays...>></div>
|
||||
<div class="foldopen" id="foldopen00102" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00102" name="l00102"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#ac813412cce77fc1340dcfefc6e099276"> 102</a></span><span class="keyword">inline</span> <span class="keyword">auto</span> <a class="code hl_function" href="namespacemlx_1_1core.html#a9d151ba3e138be1954d2f51f85806b0c">collapse_contiguous_dims</a>(Arrays&&... xs) {</div>
|
||||
<div class="line"><a id="l00103" name="l00103"></a><span class="lineno"> 103</span> <span class="keywordflow">return</span> <a class="code hl_function" href="namespacemlx_1_1core.html#a9d151ba3e138be1954d2f51f85806b0c">collapse_contiguous_dims</a>(</div>
|
||||
<div class="line"><a id="l00104" name="l00104"></a><span class="lineno"> 104</span> std::vector<array>{std::forward<Arrays>(xs)...});</div>
|
||||
<div class="line"><a id="l00105" name="l00105"></a><span class="lineno"> 105</span>}</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00120" name="l00120"></a><span class="lineno"> 120</span> </div>
|
||||
<div class="line"><a id="l00121" name="l00121"></a><span class="lineno"> 121</span>} <span class="comment">// namespace mlx::core</span></div>
|
||||
<div class="line"><a id="l00106" name="l00106"></a><span class="lineno"> 106</span> </div>
|
||||
<div class="line"><a id="l00107" name="l00107"></a><span class="lineno"> 107</span><span class="keyword">template</span> <<span class="keyword">typename</span> str<span class="keywordtype">id</span>e_t></div>
|
||||
<div class="foldopen" id="foldopen00108" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00108" name="l00108"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a847b0a276663d9ddb5cac905ee977f03"> 108</a></span><span class="keyword">inline</span> <span class="keyword">auto</span> <a class="code hl_function" href="namespacemlx_1_1core.html#a847b0a276663d9ddb5cac905ee977f03">check_contiguity</a>(</div>
|
||||
<div class="line"><a id="l00109" name="l00109"></a><span class="lineno"> 109</span> <span class="keyword">const</span> std::vector<int>& shape,</div>
|
||||
<div class="line"><a id="l00110" name="l00110"></a><span class="lineno"> 110</span> <span class="keyword">const</span> std::vector<stride_t>& strides) {</div>
|
||||
<div class="line"><a id="l00111" name="l00111"></a><span class="lineno"> 111</span> <span class="keywordtype">size_t</span> data_size = 1;</div>
|
||||
<div class="line"><a id="l00112" name="l00112"></a><span class="lineno"> 112</span> <span class="keywordtype">size_t</span> f_stride = 1;</div>
|
||||
<div class="line"><a id="l00113" name="l00113"></a><span class="lineno"> 113</span> <span class="keywordtype">size_t</span> b_stride = 1;</div>
|
||||
<div class="line"><a id="l00114" name="l00114"></a><span class="lineno"> 114</span> <span class="keywordtype">bool</span> is_row_contiguous = <span class="keyword">true</span>;</div>
|
||||
<div class="line"><a id="l00115" name="l00115"></a><span class="lineno"> 115</span> <span class="keywordtype">bool</span> is_col_contiguous = <span class="keyword">true</span>;</div>
|
||||
<div class="line"><a id="l00116" name="l00116"></a><span class="lineno"> 116</span> </div>
|
||||
<div class="line"><a id="l00117" name="l00117"></a><span class="lineno"> 117</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {</div>
|
||||
<div class="line"><a id="l00118" name="l00118"></a><span class="lineno"> 118</span> is_row_contiguous &= strides[i] == f_stride || shape[i] == 1;</div>
|
||||
<div class="line"><a id="l00119" name="l00119"></a><span class="lineno"> 119</span> is_col_contiguous &= strides[ri] == b_stride || shape[ri] == 1;</div>
|
||||
<div class="line"><a id="l00120" name="l00120"></a><span class="lineno"> 120</span> f_stride *= shape[i];</div>
|
||||
<div class="line"><a id="l00121" name="l00121"></a><span class="lineno"> 121</span> b_stride *= shape[ri];</div>
|
||||
<div class="line"><a id="l00122" name="l00122"></a><span class="lineno"> 122</span> <span class="keywordflow">if</span> (strides[i] > 0) {</div>
|
||||
<div class="line"><a id="l00123" name="l00123"></a><span class="lineno"> 123</span> data_size *= shape[i];</div>
|
||||
<div class="line"><a id="l00124" name="l00124"></a><span class="lineno"> 124</span> }</div>
|
||||
<div class="line"><a id="l00125" name="l00125"></a><span class="lineno"> 125</span> }</div>
|
||||
<div class="line"><a id="l00126" name="l00126"></a><span class="lineno"> 126</span> </div>
|
||||
<div class="line"><a id="l00127" name="l00127"></a><span class="lineno"> 127</span> <span class="keywordflow">return</span> std::make_tuple(data_size, is_row_contiguous, is_col_contiguous);</div>
|
||||
<div class="line"><a id="l00128" name="l00128"></a><span class="lineno"> 128</span>}</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00129" name="l00129"></a><span class="lineno"> 129</span> </div>
|
||||
<div class="line"><a id="l00130" name="l00130"></a><span class="lineno"> 130</span>} <span class="comment">// namespace mlx::core</span></div>
|
||||
<div class="ttc" id="aarray_8h_html"><div class="ttname"><a href="array_8h.html">array.h</a></div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1array_html"><div class="ttname"><a href="classmlx_1_1core_1_1array.html">mlx::core::array</a></div><div class="ttdef"><b>Definition</b> array.h:20</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1array_html_a0a20a6065ae71b64c1e3aa22a45fd8a1"><div class="ttname"><a href="classmlx_1_1core_1_1array.html#a0a20a6065ae71b64c1e3aa22a45fd8a1">mlx::core::array::flags</a></div><div class="ttdeci">const Flags & flags() const</div><div class="ttdoc">Get the Flags bit-field.</div><div class="ttdef"><b>Definition</b> array.h:290</div></div>
|
||||
@@ -223,9 +234,10 @@ $(function() { codefold.init(0); });
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1array_html_a4a2a2c8a4a5beafd723fc13f2055d55d"><div class="ttname"><a href="classmlx_1_1core_1_1array.html#a4a2a2c8a4a5beafd723fc13f2055d55d">mlx::core::array::shape</a></div><div class="ttdeci">const std::vector< int > & shape() const</div><div class="ttdoc">The shape of the array as a vector of integers.</div><div class="ttdef"><b>Definition</b> array.h:99</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html"><div class="ttname"><a href="namespacemlx_1_1core.html">mlx::core</a></div><div class="ttdef"><b>Definition</b> allocator.h:7</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_a4950c3248e70280b406a4f1430a85880"><div class="ttname"><a href="namespacemlx_1_1core.html#a4950c3248e70280b406a4f1430a85880">mlx::core::elem_to_loc</a></div><div class="ttdeci">stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)</div><div class="ttdef"><b>Definition</b> utils.h:12</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_a847b0a276663d9ddb5cac905ee977f03"><div class="ttname"><a href="namespacemlx_1_1core.html#a847b0a276663d9ddb5cac905ee977f03">mlx::core::check_contiguity</a></div><div class="ttdeci">auto check_contiguity(const std::vector< int > &shape, const std::vector< stride_t > &strides)</div><div class="ttdef"><b>Definition</b> utils.h:99</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_a9d151ba3e138be1954d2f51f85806b0c"><div class="ttname"><a href="namespacemlx_1_1core.html#a9d151ba3e138be1954d2f51f85806b0c">mlx::core::collapse_contiguous_dims</a></div><div class="ttdeci">std::tuple< std::vector< int >, std::vector< std::vector< stride_t > > > collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< std::vector< stride_t > > strides)</div><div class="ttdef"><b>Definition</b> utils.h:40</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_af89751d79339f3e4d9318ea97d64d114"><div class="ttname"><a href="namespacemlx_1_1core.html#af89751d79339f3e4d9318ea97d64d114">mlx::core::enable_for_arrays_t</a></div><div class="ttdeci">typename std::enable_if_t< is_arrays_v< T... > > enable_for_arrays_t</div><div class="ttdef"><b>Definition</b> array.h:565</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_a847b0a276663d9ddb5cac905ee977f03"><div class="ttname"><a href="namespacemlx_1_1core.html#a847b0a276663d9ddb5cac905ee977f03">mlx::core::check_contiguity</a></div><div class="ttdeci">auto check_contiguity(const std::vector< int > &shape, const std::vector< stride_t > &strides)</div><div class="ttdef"><b>Definition</b> utils.h:108</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_a9d151ba3e138be1954d2f51f85806b0c"><div class="ttname"><a href="namespacemlx_1_1core.html#a9d151ba3e138be1954d2f51f85806b0c">mlx::core::collapse_contiguous_dims</a></div><div class="ttdeci">std::tuple< std::vector< int >, std::vector< std::vector< stride_t > > > collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< std::vector< stride_t > > strides)</div><div class="ttdef"><b>Definition</b> utils.h:49</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_ac9fb1286a1a00395e901dbff80560895"><div class="ttname"><a href="namespacemlx_1_1core.html#ac9fb1286a1a00395e901dbff80560895">mlx::core::make_contiguous_strides</a></div><div class="ttdeci">std::vector< stride_t > make_contiguous_strides(const std::vector< int > &shape)</div><div class="ttdef"><b>Definition</b> utils.h:33</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_af89751d79339f3e4d9318ea97d64d114"><div class="ttname"><a href="namespacemlx_1_1core.html#af89751d79339f3e4d9318ea97d64d114">mlx::core::enable_for_arrays_t</a></div><div class="ttdeci">typename std::enable_if_t< is_arrays_v< T... > > enable_for_arrays_t</div><div class="ttdef"><b>Definition</b> array.h:566</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1array_1_1_flags_html_a3170fa381dc7a90f6eabcc029bdf9bfd"><div class="ttname"><a href="structmlx_1_1core_1_1array_1_1_flags.html#a3170fa381dc7a90f6eabcc029bdf9bfd">mlx::core::array::Flags::row_contiguous</a></div><div class="ttdeci">bool row_contiguous</div><div class="ttdef"><b>Definition</b> array.h:226</div></div>
|
||||
</div><!-- fragment --></div><!-- contents -->
|
||||
<!-- start footer part -->
|
||||
|
Reference in New Issue
Block a user