PyTorch 2.0發(fā)布了!一行代碼提速76%
來(lái)源:
奇酷教育 發(fā)表於:
PyTorch 2 0發(fā)布了!一行代碼提速76%
12月2日,PyTorch 2.0正式發(fā)布!
這次的更新不僅將PyTorch的性能推到了新的高度,同時(shí)也加入了對(duì)動(dòng)態(tài)形狀和分布式的支持。
此外,2.0系列還會(huì)將PyTorch的部分代碼從C++移回Python。
目前,PyTorch 2.0還處在測(cè)試階段,預計(jì)第一個(gè)穩(wěn)定版本會(huì)在2023年3月初面世。
PyTorch 2.x:更快、更Python!
在過(guò)去的幾年裡,PyTorch從1.0到最近的1.13進(jìn)行了創(chuàng)新和迭代,並轉(zhuǎn)移到新成立的PyTorch基金會(huì),成為Linux基金會(huì)的一部分。
當(dāng)前版本的PyTorch所面臨的挑戰是,eager-mode難以跟上不斷增長(zhǎng)的GPU帶寬和更瘋狂的模型架構(gòu )。
而PyTorch 2.0的誕生,將從根本上改變和提升了PyTorch在編譯器級(jí)別下的運(yùn)行方式。
眾所周知,PyTorch中的(Py)來(lái)自於數(shù)據(jù)科學(xué)中廣(guǎng)泛使用的開(kāi)源Python程式語言。
然而,PyTorch的代碼卻並沒有完全採用Python,而是把一部分交給了C++。
不過(guò),在今後的2.x系列中,PyTorch項(xiàng)目團(tuán)隊(duì)計(jì)劃(huà)將與torch.nn有關(guān)的代碼移回到Python中。
除此之外,由於PyTorch 2.0是一個(gè)完全附加的(和可選的)功能,因此2.0是100%向後兼容的。
也就是說(shuō),代碼庫(kù)是一樣的,API也是一樣的,編寫(xiě)模型的方式也是一樣的。
更多的技術(shù)支持
TorchDynamo
使用Python框架評(píng)估鉤子安全地捕獲PyTorch程序,這是團(tuán)隊(duì)5年來(lái)在graph capture方面研發(fā)的一項(xiàng)重大創(chuàng)新。
AOTAutograd
重載了PyTorch的autograd引擎,作為一個(gè)追蹤的autodiff,用於生成超前的反向追蹤。
PrimTorch
將約2000多個(gè)PyTorch運(yùn)算符歸納為約250個(gè)原始運(yùn)算符的封閉集,開(kāi)發(fā)人員可以針對(duì)這些運(yùn)算符構(gòu )建一個(gè)完整的PyTorch後端。大大降低了編寫(xiě)PyTorch功能或後端的障礙。
TorchInductor
一個(gè)深度學(xué)習(xí)編譯器,可以為多個(gè)加速器和後端生成快速代碼。對(duì)於英偉達(dá)的GPU,它使用OpenAI Triton作為關(guān)鍵構(gòu )建模塊。
值得注意的是,TorchDynamo、AOTAutograd、PrimTorch和TorchInductor都是用Python編寫(xiě)的,並支持動(dòng)態(tài)形狀。
更快的訓(xùn)練速度
通過(guò)引入新的編譯模式「torch.compile」,PyTorch 2.0用一行代碼,就可以加速模型的訓(xùn)練。
這裡不用任何技巧,只需運(yùn)行torch.compile()即可,僅此而已:
opt_module = torch.compile(module)
為了驗(yàn)證這些技術(shù),團(tuán)隊(duì)精心打造了測(cè)試基準(zhǔn),包括圖像分類(lèi)、物體檢測(cè)、圖像生成等任務(wù),以及各種NLP任務(wù),如語言建模、問(wèn)答、序列分類(lèi)、推薦系統(tǒng)和強(qiáng)化學(xué)習(xí)。其中,這些基準(zhǔn)可以分為三類(lèi):
來(lái)自HuggingFace Transformers的46個(gè)模型
來(lái)自TIMM的61個(gè)模型:Ross Wightman收集的最先進(jìn)的PyTorch圖像模型
來(lái)自TorchBench的56個(gè)模型:github的一組流行代碼庫(kù)
測(cè)試結(jié)果表明,在這163個(gè)跨越視覺、NLP和其他領(lǐng)域的開(kāi)源模型上,訓(xùn)練速度得到了38%-76%的提高。
在NVIDIA A100 GPU上的對(duì)比
此外,團(tuán)隊(duì)還在一些流行的開(kāi)源PyTorch模型上進(jìn)行了基準(zhǔn)測(cè)試,並獲得了從30%到2倍的大幅加速。
開(kāi)發(fā)者Sylvain Gugger表示:「只需添加一行代碼,PyTorch 2.0就能在訓(xùn)練Transformers模型時(shí)實(shí)現(xiàn)1.5倍到2.0倍的速度提升。這是自混合精度訓(xùn)練問(wèn)世以來(lái)最令人興奮的事情!」
團(tuán)隊(duì)之所以稱(chēng)它為 2.0,是因為它有一些標(biāo)誌性的新特性,包括:
TorchDynamo 可以從字節(jié)碼分析生成 FX 圖;
AOTAutograd 可以以 ahead-of-time 的方式生成反向圖;
PrimTorch 引入了一個(gè)小型算子集,使後端更容易;
TorchInductor:一個(gè)由 OpenAI Triton 支持的 DL 編譯器。
PyTorch 2.0 將延續(xù) PyTorch 一貫的優(yōu)勢(shì),包括 Python 集成、命令式風(fēng)格、API 簡(jiǎn)單等等。此外,PyTorch 2.0 提供了相同的 eager-mode 開(kāi)發(fā)和用戶(hù)體驗(yàn),同時(shí)從根本上改變和增強(qiáng)了 PyTorch 在編譯器級(jí)別的運(yùn)行方式。該版本能夠為「Dynamic Shapes」和分布式運(yùn)行提供更快的性能和更好的支持。
在官方博客中,PyTorch團(tuán)隊(duì)還公布了他們對(duì)於整個(gè)2.0系列的展望:
以下是詳細(xì)內(nèi)容。
PyTorch 2.X:速度更快、更加地 Python 化、一如既往地 dynamic
PyTorch 2.0 官宣了一個(gè)重要特性——torch.compile,這一特性將 PyTorch 的性能推向了新的高度,並將 PyTorch 的部分內(nèi)容從 C++ 移回 Python。torch.compile 是一個(gè)完全附加的(可選的)特性,因此 PyTorch 2.0 是 100% 向後兼容的。
支撐 torch.compile 的技術(shù)包括研發(fā)團(tuán)隊(duì)新推出的 TorchDynamo、AOTAutograd、PrimTorch 和 TorchInductor。
TorchDynamo 使用 Python Frame Evaluation Hooks 安全地捕獲 PyTorch 程序,這是一項(xiàng)重大創(chuàng)新,是研究團(tuán)隊(duì)對(duì)快速可靠地獲取圖進(jìn)行 5 年研發(fā)的結(jié)果;
AOTAutograd 重載 PyTorch 的 autograd 引擎作為一個(gè)跟蹤 autodiff,用於生成 ahead-of-time 向後跟蹤;
PrimTorch 將約 2000 多個(gè) PyTorch 算子規(guī)範化為一組約 250 個(gè)原始算子的閉集,開(kāi)發(fā)人員可以將其作為構(gòu )建完整 PyTorch 後端的目標(biāo)。這大大降低了編寫(xiě) PyTorch 特性或後端的障礙;
TorchInductor 是一種深度學(xué)習(xí)編譯器,可為多個(gè)加速器和後端生成快速代碼。對(duì)於 NVIDIA GPU,它使用 OpenAI Triton 作為關(guān)鍵構(gòu )建塊。
TorchDynamo、AOTAutograd、PrimTorch 和 TorchInductor 是用 Python 編寫(xiě)的,並支持 dynamic shapes(即能夠發(fā)送不同大小的張量而無(wú)需重新編譯),這使得它們具備靈活、易於破解的特性,降低了開(kāi)發(fā)人員和供應(yīng)商的使用門(mén)檻。
為了驗(yàn)證這些技術(shù),研發(fā)團(tuán)隊(duì)在各種機(jī)器學(xué)習(xí)領(lǐng)域測(cè)試了 163 個(gè)開(kāi)源模型。實(shí)驗(yàn)精心構(gòu )建了測(cè)試基準(zhǔn),包括各種 CV 任務(wù)(圖像分類(lèi)、目標(biāo)檢測(cè)、圖像生成等)、NLP 任務(wù)(語言建模、問(wèn)答、序列分類(lèi)、推薦系統(tǒng)等)和強(qiáng)化學(xué)習(xí)任務(wù),測(cè)試模型主要有 3 個(gè)來(lái)源:
46 個(gè)來(lái)自 HuggingFace Transformers 的模型;
來(lái)自 TIMM 的 61 個(gè)模型:一系列 SOTA PyTorch 圖像模型;
來(lái)自 TorchBench 的 56 個(gè)模型:包含來(lái)自 github 的精選流行代碼庫(kù)。
然後研究者測(cè)量加速性能並驗(yàn)證這些模型的準(zhǔn)確性。加速可能取決於數(shù)據(jù)類(lèi)型,研究團(tuán)隊(duì)選擇測(cè)量 float32 和自動(dòng)混合精度 (AMP) 的加速。
在 163 個(gè)開(kāi)源模型中,torch.compile 在 93% 的情況下都有效,模型在 NVIDIA A100 GPU 上的訓(xùn)練速度提高了 43%。在 float32 精度下,它的平均運(yùn)行速度提高了 21%,而在 AMP 精度下,它的運(yùn)行速度平均提高了 51%。
目前,torch.compile 還處於早期開(kāi)發(fā)階段,預計(jì) 2023 年 3 月上旬將發(fā)布第一個(gè)穩(wěn)定的 2.0 版本。
TorchDynamo:快速可靠地獲取圖
TorchDynamo 是一種使用 Frame Evaluation API (PEP-0523 中引入的一種 CPython 特性)的新方法。研發(fā)團(tuán)隊(duì)採用數(shù)據(jù)驅動(dòng)的方法來(lái)驗(yàn)證其在 Graph Capture 上的有效性,並使用 7000 多個(gè)用 PyTorch 編寫(xiě)的 Github 項(xiàng)目作為驗(yàn)證集。TorchScript 等方法大約在 50% 的時(shí)間裡都難以獲取圖,而且通常開(kāi)銷(xiāo)很大;而 TorchDynamo 在 99% 的時(shí)間裡都能獲取圖,方法正確、安全且開(kāi)銷(xiāo)可忽略不計(jì)(無(wú)需對(duì)原始代碼進(jìn)行任何更改)。這說(shuō)明 TorchDynamo 突破了多年來(lái)模型權(quán)衡靈活性和速度的瓶頸。
TorchInductor:使用 define-by-run IR 快速生成代碼
對(duì)於 PyTorch 2.0 的新編譯器後端,研發(fā)團(tuán)隊(duì)從用戶(hù)編寫(xiě)高性能自定義內(nèi)核的方式中汲取靈感:越來(lái)越多地使用 Triton 語言。此外,研究者還想要一個(gè)編譯器後端——使用與 PyTorch eager 類(lèi)似的抽象,並且具有足夠的通用性以支持 PyTorch 中廣(guǎng)泛的功能。
TorchInductor 使用 pythonic define-by-run loop level IR 自動(dòng)將 PyTorch 模型映射到 GPU 上生成的 Triton 代碼和 CPU 上的 C++/OpenMP。TorchInductor 的 core loop level IR 僅包含約 50 個(gè)算子,並且是用 Python 實(shí)現(xiàn)的,易於破解和擴展。
AOTAutograd:將 Autograd 重用於 ahead-of-time 圖
PyTorch 2.0 的主要特性之一是加速訓(xùn)練,因此 PyTorch 2.0 不僅要捕獲用戶(hù)級(jí)代碼,還要捕獲反向傳播。此外,研發(fā)團(tuán)隊(duì)還想要復用現(xiàn)有的經(jīng)過(guò)實(shí)踐檢驗(yàn)的 PyTorch autograd 系統(tǒng)。AOTAutograd 利用 PyTorch 的 torch_dispatch 可擴展機(jī)制來(lái)跟蹤 Autograd 引擎,使其能夠「ahead-of-time」捕獲反向傳遞(backwards pass)。這使 TorchInductor 能夠加速前向和反向傳遞。
PrimTorch:穩(wěn)定的原始算子
為 PyTorch 編寫(xiě)後端具有挑戰性。PyTorch 有 1200 多個(gè)算子,如果考慮每個(gè)算子的各種重載,則有 2000 多個(gè)。
在 PrimTorch 項(xiàng)目中,研發(fā)團(tuán)隊(duì)致力於定義更小且穩(wěn)定的算子集,將 PyTorch 程序縮減到這樣較小的算子集。目標(biāo)是定義兩(liǎng)個(gè)算子集:
Prim ops:約有 250 個(gè)相當(dāng)低級(jí)的算子。這些算子適用於編譯器,需要將它們重新融合在一起以獲得良好的性能;
ATen ops:約有 750 個(gè)規(guī)範算子。這些算子適用於已經(jīng)在 ATen 級(jí)別集成的後端或沒有編譯功能的後端(無(wú)法從較低級(jí)別的算子集(如 Prim ops)恢復性能)。