mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
rebase
This commit is contained in:
@@ -236,8 +236,8 @@ $(function(){ initResizable(false); });
|
||||
<div class="line"><a id="l00142" name="l00142"></a><span class="lineno"> 142</span> <span class="comment">// Store results to device memory</span></div>
|
||||
<div class="line"><a id="l00143" name="l00143"></a><span class="lineno"> 143</span> {</div>
|
||||
<div class="line"><a id="l00144" name="l00144"></a><span class="lineno"> 144</span> <span class="comment">// Adjust for simdgroup and thread locatio</span></div>
|
||||
<div class="line"><a id="l00145" name="l00145"></a><span class="lineno"> 145</span> <span class="keywordtype">int</span> offset_m = c_row + mma_op.sm + mma_op.tm;</div>
|
||||
<div class="line"><a id="l00146" name="l00146"></a><span class="lineno"> 146</span> <span class="keywordtype">int</span> offset_n = c_col + mma_op.sn + mma_op.tn;</div>
|
||||
<div class="line"><a id="l00145" name="l00145"></a><span class="lineno"> 145</span> <span class="keywordtype">int</span> offset_m = c_row + mma_op.sm;</div>
|
||||
<div class="line"><a id="l00146" name="l00146"></a><span class="lineno"> 146</span> <span class="keywordtype">int</span> offset_n = c_col + mma_op.sn;</div>
|
||||
<div class="line"><a id="l00147" name="l00147"></a><span class="lineno"> 147</span> C += offset_n;</div>
|
||||
<div class="line"><a id="l00148" name="l00148"></a><span class="lineno"> 148</span> </div>
|
||||
<div class="line"><a id="l00149" name="l00149"></a><span class="lineno"> 149</span> <span class="keywordflow">if</span> (offset_n >= gemm_params->N)</div>
|
||||
@@ -263,17 +263,17 @@ $(function(){ initResizable(false); });
|
||||
<div class="line"><a id="l00169" name="l00169"></a><span class="lineno"> 169</span> <a class="code hl_define" href="steel_2defines_8h.html#a5a5c3095b132a7589bc19cd5cb80e2c6">STEEL_PRAGMA_UNROLL</a></div>
|
||||
<div class="line"><a id="l00170" name="l00170"></a><span class="lineno"> 170</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> j = 0; j < mma_t::TN; j++) {</div>
|
||||
<div class="line"><a id="l00171" name="l00171"></a><span class="lineno"> 171</span> <span class="comment">// Get accumulated result and associated offset in C</span></div>
|
||||
<div class="line"><a id="l00172" name="l00172"></a><span class="lineno"> 172</span> thread <span class="keyword">const</span> <span class="keyword">auto</span>& accum =</div>
|
||||
<div class="line"><a id="l00173" name="l00173"></a><span class="lineno"> 173</span> mma_op.results[i * mma_t::TN + j].thread_elements();</div>
|
||||
<div class="line"><a id="l00174" name="l00174"></a><span class="lineno"> 174</span> <span class="keywordtype">int</span> offset = offset_cm + (j * mma_t::TN_stride);</div>
|
||||
<div class="line"><a id="l00175" name="l00175"></a><span class="lineno"> 175</span> </div>
|
||||
<div class="line"><a id="l00176" name="l00176"></a><span class="lineno"> 176</span> <span class="comment">// Apply epilogue and output C</span></div>
|
||||
<div class="line"><a id="l00177" name="l00177"></a><span class="lineno"> 177</span> <span class="keywordflow">if</span> (j * mma_t::TN_stride < diff) {</div>
|
||||
<div class="line"><a id="l00178" name="l00178"></a><span class="lineno"> 178</span> C[offset] = Epilogue::apply(accum[0]);</div>
|
||||
<div class="line"><a id="l00179" name="l00179"></a><span class="lineno"> 179</span> }</div>
|
||||
<div class="line"><a id="l00180" name="l00180"></a><span class="lineno"> 180</span> </div>
|
||||
<div class="line"><a id="l00181" name="l00181"></a><span class="lineno"> 181</span> <span class="keywordflow">if</span> (j * mma_t::TN_stride + 1 < diff) {</div>
|
||||
<div class="line"><a id="l00182" name="l00182"></a><span class="lineno"> 182</span> C[offset + 1] = Epilogue::apply(accum[1]);</div>
|
||||
<div class="line"><a id="l00172" name="l00172"></a><span class="lineno"> 172</span> thread <span class="keyword">const</span> <span class="keyword">auto</span>& accum = mma_op.Ctile.frag_at(i, j);</div>
|
||||
<div class="line"><a id="l00173" name="l00173"></a><span class="lineno"> 173</span> <span class="keywordtype">int</span> offset = offset_cm + (j * mma_t::TN_stride);</div>
|
||||
<div class="line"><a id="l00174" name="l00174"></a><span class="lineno"> 174</span> </div>
|
||||
<div class="line"><a id="l00175" name="l00175"></a><span class="lineno"> 175</span> <span class="keyword">constexpr</span> <span class="keywordtype">short</span> kelems = <span class="keyword">decltype</span>(mma_op.Ctile)::kElemsPerFrag;</div>
|
||||
<div class="line"><a id="l00176" name="l00176"></a><span class="lineno"> 176</span> </div>
|
||||
<div class="line"><a id="l00177" name="l00177"></a><span class="lineno"> 177</span> <span class="comment">// Apply epilogue and output C</span></div>
|
||||
<div class="line"><a id="l00178" name="l00178"></a><span class="lineno"> 178</span> <a class="code hl_define" href="steel_2defines_8h.html#a5a5c3095b132a7589bc19cd5cb80e2c6">STEEL_PRAGMA_UNROLL</a></div>
|
||||
<div class="line"><a id="l00179" name="l00179"></a><span class="lineno"> 179</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> k = 0; k < kelems; k++) {</div>
|
||||
<div class="line"><a id="l00180" name="l00180"></a><span class="lineno"> 180</span> <span class="keywordflow">if</span> ((j * mma_t::TN_stride + k) < diff) {</div>
|
||||
<div class="line"><a id="l00181" name="l00181"></a><span class="lineno"> 181</span> C[offset + k] = Epilogue::apply(accum[k]);</div>
|
||||
<div class="line"><a id="l00182" name="l00182"></a><span class="lineno"> 182</span> }</div>
|
||||
<div class="line"><a id="l00183" name="l00183"></a><span class="lineno"> 183</span> }</div>
|
||||
<div class="line"><a id="l00184" name="l00184"></a><span class="lineno"> 184</span> }</div>
|
||||
<div class="line"><a id="l00185" name="l00185"></a><span class="lineno"> 185</span> }</div>
|
||||
|
Reference in New Issue
Block a user