AI 算力資源越發(fā)緊張的當(dāng)下,斯坦福新研究將 GPU 運行效率再提升一波 —— 內(nèi)核只有 100 行代碼,讓 H100 比使用 FlashAttention-2,性能還要提升 30%。
怎么做到的?
研究人員從“硬件實際需要什么?如何滿足這些需求?”這兩個問題出發(fā),設(shè)計了 一個嵌入式 CUDA DSL 工具,名為 ThunderKittens(暫且譯為雷貓)。
雷貓可簡化 AI 內(nèi)核的編寫,同時充分利用底層硬件能力。
網(wǎng)友們對此討論也十分熱烈。
有網(wǎng)友表示讀這篇 Blog 時,讓他想起了初次了解超標(biāo)量 CPU 架構(gòu)時的驚訝感受:
GPU 真的達(dá)到了新高度。
所以要充分發(fā)揮 H100 的能力,關(guān)鍵是保持張量核心持續(xù)運算。
然鵝,要保持張量核心持續(xù)運行并不容易。
研究人員發(fā)現(xiàn) GPU 硬件具有一些特性,對于保持矩陣乘法的運行非常重要:
WGMMA 指令雖然是必要的,但使用起來頗為麻煩。
共享內(nèi)存的速度并不如預(yù)期的快,使用時還需格外注意。
生成地址的成本較高。
保持高占用率對于提升性能是有益的,寄存器至關(guān)重要。
這些特性在非 H100 GPU 上也有所適用,在 H100 上更加典型,就拿 RTX 4090 來說,相比 H100 處理起來簡單得多。
最終發(fā)現(xiàn),這些布局只適用于特定矩陣形狀,并與 wgmma.mma_async 指令的其他部分不兼容,例如硬件僅在未重排的布局下轉(zhuǎn)置子矩陣。
此外,未重排的 wgmma 布局內(nèi)存合并性差且有 bank conflicts。盡管 TMA 和 L2 緩存在如 flash attention 這類內(nèi)核上能較好地掩蓋這些問題,但要充分利用硬件,必須精心控制內(nèi)存請求的合并和避免 bank conflicts。
盡管有這些問題,但這些指令對于充分利用 H100 是必不可少的。沒有它們,GPU 的潛在性能就損失了 37%。
共享內(nèi)存的單次訪問延遲約為 30 個周期(這也與研究人員觀察的相符),這看似不多,但在這段時間內(nèi),SM 的張量核心幾乎能完成兩次完整的 32x32 方陣乘法。
以前的研究,如 Flash Attention,研究人員更多關(guān)注的是 HBM-SRAM 的瓶頸。但隨著 HBM 速度的提升和張量核心的快速發(fā)展,即使是共享內(nèi)存的相對較小延遲也變得尤為關(guān)鍵。
由于共享內(nèi)存被分為 32 個獨立的存儲單元,處理不當(dāng)可能會引發(fā) bank conflicts,即同一個內(nèi)存 bank 同時被多個請求訪問,這種情況會導(dǎo)致請求被序列化。研究人員實驗后認(rèn)為,這會顯著拖慢內(nèi)核速度,且 wgmma 與 mma 指令需要的寄存器布局容易受到 bank conflicts 的影響。
解決方法是通過各種“重排”模式調(diào)整共享內(nèi)存的配置,避免 bank conflicts,但細(xì)節(jié)要處理得當(dāng)。
此外研究人員發(fā)現(xiàn),盡可能避免在寄存器和共享內(nèi)存之間的移動數(shù)據(jù)非常重要??赡艿脑挘墒褂脙?nèi)置硬件(如 wgmma 和 TMA 指令)進(jìn)行異步數(shù)據(jù)傳輸。實在沒法子了,再使用 warp 進(jìn)行同步數(shù)據(jù)傳輸。
H100 還有一個有趣的特性,其張量核心和內(nèi)存都足夠快,以至于僅生成用于獲取數(shù)據(jù)的內(nèi)存地址就占用了芯片的大量資源,特別是加入復(fù)雜的交錯或重排模式時,這種情況更為明顯。
研究人員表示,英偉達(dá)提供了張量內(nèi)存加速器(TMA),似乎就是已經(jīng)意識到了這個問題。
TMA 允許用戶在全局和共享內(nèi)存中指定多維張量布局,命令其異步提取張量的一部分,并在完成后觸發(fā)一個屏障。這大大節(jié)省了地址生成的開銷,并簡化了 pipelines 的構(gòu)建。
研究人員認(rèn)為,TMA 對于充分發(fā)揮 H100 的潛力至關(guān)重要,可能比 wgmma.mma_async 更為關(guān)鍵。
它不僅節(jié)省了寄存器資源和指令派發(fā),還提供了如異步在全局內(nèi)存上執(zhí)行歸約等實用功能 —— 這在處理復(fù)雜的反向內(nèi)核時尤其有用。
雖然 TMA 的重排模式解讀有一定難度,需要進(jìn)行一些逆向工程,但研究人員表示,相比之下,他們在這上面遇到的問題要少得多。
占用率指的是在 GPU 的相同執(zhí)行硬件上同時調(diào)度的線程數(shù)。每個周期,SM 的某一子單元的 warp scheduler 會嘗試向準(zhǔn)備就緒的 warp 線程發(fā)出指令。
研究人員認(rèn)為,英偉達(dá)采用這種模型可以更容易地保持硬件的滿負(fù)荷運行。例如,當(dāng)一個線程 warp 等待執(zhí)行矩陣乘法時,另一個可以被指派執(zhí)行使用快速指數(shù)運算的指令。
在某些方面,H100 對占用率的依賴程度低于前幾代硬件。
它的異步特性使得即使單一指令流也能使多個硬件部分同時持續(xù)運行,包括讀取內(nèi)存、執(zhí)行矩陣乘法、進(jìn)行共享內(nèi)存的歸約,同時還能在寄存器上進(jìn)行計算。
但高占用率容易隱藏缺陷或同步問題,一個設(shè)計良好的 pipeline 即使在占用率不高的情況下也能運行得相當(dāng)快。
據(jù)研究人員觀察,英偉達(dá)在設(shè)計 GPU 時確實考慮到了占用率。且由于存在足夠多的同步操作和足夠多的錯誤可能性,根據(jù)他們的經(jīng)驗,提高占用率通常能顯著增加硬件的實際利用率。
此外,相比 H100,A100 和 RTX 4090 更依賴同步指令調(diào)度,占用率更重要。
鑒于以上情況,如何才能更輕松地編寫所需的內(nèi)核類型,同時充分發(fā)揮硬件的全部潛力?
雷貓(ThunderKittens)登場了。
這是一個嵌入在 CUDA 中的 DSL,本是斯坦福研究人員設(shè)計出來給自己內(nèi)部使用的,后來發(fā)現(xiàn)還真挺好使。
Ps:起這么個名,一是他們覺得小貓很可愛,二來他們覺得大伙兒在代碼中輸入 kittens:: 會很有趣。
具體來說,雷貓包含四種模板類型:
tiles 通過高度、寬度和布局進(jìn)行參數(shù)化;寄存器向量通過長度和布局進(jìn)行參數(shù)化;而共享向量僅通過長度進(jìn)行參數(shù)化,通常不會遇到 bank conflicts 問題。
此外,研究人員提供了一系列操作來處理這些張量,既可在 warp 級別使用,也可用于多個 warp 協(xié)作,包含初始化器,如將共享向量清零;一元操作,如 exp;二元操作,如 mul;行 / 列操作,例如行求和。
雷貓作為一個嵌入到 CUDA 中的庫,其提供的抽象層在遇到不支持的功能時能夠很好地處理。如果雷貓缺少某些功能,可以直接擴(kuò)展它來實現(xiàn)你想要的效果。
以 Tri 的 flash attention 算法為例,在實際應(yīng)用中,即使是使用英偉達(dá)的 Cutlass 庫,實現(xiàn)起來也是相當(dāng)復(fù)雜。
以下是一個在 RTX 4090 上使用雷貓編寫的簡單 flash attention 內(nèi)核的示例。
總共約 60 行 CUDA 代碼,硬件利用率達(dá)到了 75%。代碼復(fù)雜性主要在于算法本身,而非交織模式或寄存器布局。
#define NUM_WORKERS 16 // This kernel uses 16 workers in parallel per block, to help issue instructions more quickly. using namespace kittens; // this kernel only handles headdim=64 for simplicity. Also n should be a multiple of 256 here. __global__ void attend_ker64(int n, const bf16* __restrict__ __q__, const bf16* __restrict__ __k__, const bf16* __restrict__ __v__, bf16* __o__) { auto warpid = kittens::warpid(); auto block_start = blockIdx.x*(n*64); const bf16 *_q = __q__ + block_start, *_k = __k__ + block_start, *_v = __v__ + block_start; bf16 *_o = __o__ + block_start; extern __shared__ alignment_dummy __shm[]; // this is the CUDA shared memory shared_allocator al((int*)&__shm[0]); // K and V live in shared memory -- this is about all that will fit. st_bf_1x4<ducks::st_layout::swizzle> (&k_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>(); st_bf_1x4<ducks::st_layout::swizzle> (&v_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>(); // Initialize all of the register tiles. rt_bf_1x4<> q_reg, k_reg, v_reg; // v_reg need to be swed into col_l rt_fl_1x1<> att_block; rt_bf_1x1<> att_block_mma; rt_fl_1x4<> o_reg; rt_fl_1x1<>::col_vec max_vec_last, max_vec; // these are column vectors for the attention block rt_fl_1x1<>::col_vec norm_vec_last, norm_vec; // these are column vectors for the attention block int qo_blocks = n / (q_reg.rows*NUM_WORKERS), kv_blocks = n / (q_reg.rows*NUM_WORKERS); for(auto q_blk = 0; q_blk < qo_blocks; q_blk++) { // each warp loads its own Q tile of 16x64, and then multiplies by 1/sqrt(d) load(q_reg, _q + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols); mul(q_reg, q_reg, __float2bfloat16(0.125f)); // temperature adjustment // zero flash attention L, M, and O registers. neg_infty(max_vec); // zero registers for the Q chunk zero(norm_vec); zero(o_reg); // iterate over k, v for these q's that have been loaded for(auto kv_idx = 0; kv_idx < kv_blocks; kv_idx++) { // each warp loads its own chunk of k, v into shared memory load(v_smem[warpid], _v + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols); load(k_smem[warpid], _k + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols); __syncthreads(); // we need to make sure all memory is loaded before we can begin the compute phase // now each warp goes through all of the subtiles, loads them, and then does the flash attention internal alg. for(int subtile = 0; subtile < NUM_WORKERS; subtile++) { load(k_reg, k_smem[subtile]); // load k from shared into registers zero(att_block); // zero 16x16 attention tile mma_ABt(att_block, q_reg, k_reg, att_block); // [email protected] copy(norm_vec_last, norm_vec); copy(max_vec_last, max_vec); row_max(max_vec, att_block, max_vec); // accumulate onto the max_vec sub_row(att_block, att_block, max_vec); // subtract max from attention -- now all <=0 exp(att_block, att_block); // exponentiate the block in-place. sub(max_vec_last, max_vec_last, max_vec); // subtract new max from old max to find the new normalization. exp(max_vec_last, max_vec_last); // exponentiate this vector -- this is what we need to normalize by. mul(norm_vec, norm_vec, max_vec_last); // and the norm vec is now normalized. row_sum(norm_vec, att_block, norm_vec); // accumulate the new attention block onto the now-rescaled norm_vec div_row(att_block, att_block, norm_vec); // now the attention block is correctly normalized mul(norm_vec_last, norm_vec_last, max_vec_last); // normalize the previous norm vec according to the new max div(norm_vec_last, norm_vec_last, norm_vec); // normalize the previous norm vec according to the new norm copy(att_block_mma, att_block); // convert to bf16 for mma_AB load(v_reg, v_smem[subtile]); // load v from shared into registers. rt_bf_1x4<ducks::rt_layout::col> &v_reg_col = swap_layout_inplace(v_reg); // this is a reference and the call has invalidated v_reg mul_row(o_reg, o_reg, norm_vec_last); // normalize o_reg in advance of mma_AB'ing onto it mma_AB(o_reg, att_block_mma, v_reg_col, o_reg); // mfma onto o_reg with the local attention@V matmul. } __syncthreads(); // we need to make sure all warps are done before we can start loading the next kv chunk } store(_o + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, o_reg, q_reg.cols); // write out o. compiler has an issue with register usage if d is made constexpr q_reg.rows :/ } }
關(guān)于 TMA、WGMMA、交織模式和描述符的復(fù)雜性,這里展示了一個使用雷貓編寫的,針對 H100 的 FlashAttention-2 算法的前向傳遞示例。
template<int D> __global__ __launch_bounds__((NUM_WORKERS)*kittens::WARP_THREADS, 2) void fwd_attend_ker_dim(int N, const CUtensorMap* tma_q, const CUtensorMap* tma_k, const CUtensorMap* tma_v, CUtensorMap* tma_o) { extern __shared__ int __shm[]; // this is the CUDA shared memory tma_swizzle_allocator al((int*)&__shm[0]); constexpr int tile_width = fwd_attend_ker_tile_dims<D>::tile_width; // constants constexpr int qo_height = fwd_attend_ker_tile_dims<D>::qo_height; constexpr int kv_height = fwd_attend_ker_tile_dims<D>::kv_height; st_bf<qo_height, tile_width, layout_q> (&q_smem) [NUM_WARPGROUPS] = al.allocate<st_bf<qo_height, tile_width, layout_q>, NUM_WARPGROUPS>(); st_bf<kv_height, tile_width, layout_k> (&k_smem)[2][NUM_WORKERS_KV] = al.allocate<st_bf<kv_height, tile_width, layout_k>, 2, NUM_WORKERS_KV>(); st_bf<kv_height, tile_width, layout_v> (&v_smem)[2][NUM_WORKERS_KV] = al.allocate<st_bf<kv_height, tile_width, layout_v>, 2, NUM_WORKERS_KV>(); int tic = 0, toc = 1; rt_fl<1, kv_height> att_block; rt_bf<1, kv_height> att_block_mma; rt_fl<1, qo_height> o_prev; col_vec<rt_fl<1, kv_height>> max_vec_last, max_vec; col_vec<rt_fl<1, kv_height>> norm_vec_last, norm_vec; int warpid = kittens::warpid(); int warpgroupid = warpid/kittens::WARPGROUP_WARPS; int kv_blocks = N / (NUM_WORKERS_KV*k_smem[0][0].rows); __shared__ uint64_t qsmem_barrier, kvsmem_barrier;//, vsmem_barrier; int q_phasebit = 0; int kv_phasebit = 0; if (threadIdx.x == 0) { tma::init_barrier<st_bf<qo_height, tile_width, layout_q>, NUM_WARPGROUPS>(qsmem_barrier, 1); tma::init_barrier<st_bf<kv_height, tile_width, layout_k>, NUM_WORKERS_KV*2>(kvsmem_barrier, 1); } if (warpid == 0) { for (int wg = 0; wg < NUM_WORKERS/kittens::WARPGROUP_WARPS; wg++) { // load q int tile_idx = (blockIdx.y * NUM_WARPGROUPS * gridDim.x) + (blockIdx.x * NUM_WARPGROUPS) + wg; tma::load_async((q_smem[wg]), tma_q, qsmem_barrier, tile_idx); } for (int w = 0; w < NUM_WORKERS_KV; w++) { // load k, v int tile_idx = (blockIdx.y * NUM_WORKERS_KV * kv_blocks) + (0 * NUM_WORKERS_KV) + w; tma::load_async((k_smem[tic][w]), tma_k, kvsmem_barrier, tile_idx); tma::load_async((v_smem[tic][w]), tma_v, kvsmem_barrier, tile_idx); } } neg_infty(max_vec); // zero registers for the Q chunk zero(norm_vec); zero(o_prev); __syncthreads(); tma::arrive_and_wait(qsmem_barrier, q_phasebit); q_phasebit ^= 1; if constexpr (D == 64) { warpgroup::mul(q_smem[warpgroupid], q_smem[warpgroupid], __float2bfloat16(0.125f)); } else { warpgroup::mul(q_smem[warpgroupid], q_smem[warpgroupid], __float2bfloat16(0.08838834764f)); } for (auto kv_idx = 0; kv_idx < kv_blocks; kv_idx++, tic ^= 1, toc ^= 1) { tma::arrive_and_wait(kvsmem_barrier, kv_phasebit); kv_phasebit ^= 1; __syncthreads(); if (warpid == 0) { tma::set_bytes(kvsmem_barrier, 2 * NUM_WORKERS_KV * k_smem[0][0].num_elements * sizeof(bf16)); if (kv_idx + 1 < kv_blocks) { for (int w = 0; w < NUM_WORKERS_KV; w++) { int tile_idx = (blockIdx.y * NUM_WORKERS_KV * kv_blocks) + ((kv_idx + 1) * NUM_WORKERS_KV) + w; tma::load_async((k_smem[toc][w]), tma_k, kvsmem_barrier, tile_idx); tma::load_async((v_smem[toc][w]), tma_v, kvsmem_barrier, tile_idx); } } } warpgroup::mma_fence(att_block); warpgroup::mm_ABt(att_block, q_smem[warpgroupid], k_smem[tic][0]); warpgroup::mma_commit_group(); copy(norm_vec_last, norm_vec); copy(max_vec_last, max_vec); warpgroup::mma_async_wait(); row_max(max_vec, att_block, max_vec); // accumulate onto the max_vec sub_row(att_block, att_block, max_vec); exp(att_block, att_block); sub(max_vec_last, max_vec_last, max_vec); exp(max_vec_last, max_vec_last); mul(norm_vec, norm_vec, max_vec_last); row_sum(norm_vec, att_block, norm_vec); // accumulate onto the norm_vec div_row(att_block, att_block, norm_vec); mul(norm_vec_last, norm_vec_last, max_vec_last); div(norm_vec_last, norm_vec_last, norm_vec); copy(att_block_mma, att_block); // convert to bf16 for mma mul_row(o_prev, o_prev, norm_vec_last); // normalize o_prev in advance of mma'ing onto it warpgroup::mma_fence(o_prev); warpgroup::mma_AB(o_prev, att_block_mma, v_smem[tic][0]); warpgroup::mma_commit_group(); } auto (*o_smem) = reinterpret_cast<st_bf<qo_height, tile_width, layout_o>(*)>(q_smem); // reuse q memory warpgroup::store(o_smem[warpgroupid], o_prev); __syncthreads(); if (warpid % 4 == 0) { // store o int tile_idx = (blockIdx.y * NUM_WARPGROUPS * gridDim.x) + (blockIdx.x * NUM_WARPGROUPS) + warpgroupid; tma::store_async(tma_o, (o_smem[warpgroupid]), tile_idx); tma::store_commit_group(); } tma::store_async_wait(); }
那么,它的表現(xiàn)如何?
這個內(nèi)核只有 100 行代碼,實際上它在 H100 上的性能比 FlashAttention-2 高出約 30%。雷貓負(fù)責(zé)包裝布局和指令,提供了一個可以在 GPU 上使用的迷你 pytorch 環(huán)境。
△FA2(通過 Pytorch 實現(xiàn))與 TK 在 H100 SXM 上的多種配置比較
此外,研究人員還發(fā)布了基于線性注意力和其他新架構(gòu)的內(nèi)核。其中基于線性注意力的內(nèi)核的運行速度可達(dá) 215 TFLOPs,如果考慮到算法中固有的重計算,速度可超過 300 TFLOPs。
盡管線性注意力在理論上效率更高,但此前在實際硬件上表現(xiàn)并不佳。因此,研究人員認(rèn)為這可能促進(jìn)一系列高吞吐量應(yīng)用的發(fā)展。
最后,雷貓研究團(tuán)隊總結(jié)了開發(fā)雷貓的一些思考。在他們看來,雷貓之所以有效,是因為它的目標(biāo)并不是試圖做所有事:
CUDA 的確比雷貓表達(dá)能力更廣,雷貓小而簡單,功能有限。但雷貓的 small tiles 抽象設(shè)計符合 AI 和硬件的發(fā)展趨勢。
雖然雷貓不支持小于 16 的維度,但研究人員認(rèn)為這并不重要,因為硬件也不傾向于支持過小的維度。
如果你的矩陣乘法小于 16x16,你確定你正在做的是 AI 嗎?
從理論出發(fā),研究人員認(rèn)為需要進(jìn)行一種框架轉(zhuǎn)變。
“寄存器當(dāng)然不應(yīng)該像舊 CPU 那樣 32 位字。CUDA 使用的 1024 位寬向量寄存器確實是朝著正確方向邁出的一步。但對我們來說,寄存器是 16x16 的數(shù)據(jù) tile。我們認(rèn)為 AI 需要這樣的設(shè)計,畢竟,它仍然只是矩陣乘法、規(guī)約和重塑。我們認(rèn)為硬件也需要這樣的設(shè)計,小型矩陣乘法迫切需要超出系統(tǒng)級 MMA 的硬件支持?!?/p>
研究人員認(rèn)為,應(yīng)該根據(jù)硬件特性來重新定義 AI 的設(shè)計理念。例如,循環(huán)狀態(tài)應(yīng)該有多大?應(yīng)該足夠大以適應(yīng)一個 SM。計算的密度應(yīng)該有多高?不應(yīng)低于硬件的需求。
我們未來工作的一個重要方向是利用我們對硬件的了解來幫助我們設(shè)計與之匹配的 AI。
還沒有評論,來說兩句吧...