Author: Simon Fan ๋ฒ์ญ: ์ดํ์ค
.. grid:: 2
.. grid-item-card:: :octicon:`mortar-board;1em;` ๋ฌด์์ ๋ฐฐ์ธ ์ ์๋์?
:class-card: card-prerequisites
* Compiled Autograd๊ฐ ``torch.compile`` ์ ์ํธ์์ฉํ๋ ๋ฐฉ์
* Compiled Autograd API๋ฅผ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ
* ``TORCH_LOGS`` ๋ฅผ ์ฌ์ฉํ์ฌ ๋ก๊ทธ๋ฅผ ๊ฒ์ฌํ๋ ๋ฐฉ๋ฒ
.. grid-item-card:: :octicon:`list-unordered;1em;` ์ ์ ์กฐ๊ฑด
:class-card: card-prerequisites
* PyTorch 2.4
* `Introduction to torch.compile <https://tutorials.pytorch.kr/intermediate/torch_compile_tutorial.html>`_ ์๋ฃ
* `Get Started with PyTorch 2.x <https://pytorch.org/get-started/pytorch-2.0/>`_ ์ TorchDynamo์ AOTAutograd ๋ถ๋ถ์ ์ฝ์ด๋ณด์ธ์.
Compiled Autograd๋ PyTorch 2.4 ์์ ์๊ฐ๋ torch.compile ํ์ฅ ๊ธฐ๋ฅ์ผ๋ก, ๋ ํฐ backward ๊ทธ๋ํ๋ฅผ ์บก์ณํ ์ ์๊ฒ ํด์ค๋๋ค.
torch.compile ์ด backward ๊ทธ๋ํ๋ฅผ ํฌ์ํ๊ธด ํ์ง๋ง, ๊ทธ๊ฒ์ ๋ถ๋ถ์ ์ผ๋ก๋ง ์ด๋ฃจ์ด์ง๋๋ค. AOTAutograd ์ปดํฌ๋ํธ๋ backward ๊ทธ๋ํ๋ฅผ ์ฌ์ ์ ์บก์ณํ์ง๋ง, ๋ช ๊ฐ์ง ์ ํ์ฌํญ์ด ์กด์ฌํฉ๋๋ค.
- forward ์ฐ์ฐ์์ ๊ทธ๋ํ ๋ถ์ ์ด ์ผ์ด๋๋ฉด, backward ์ฐ์ฐ์์๋ ๊ทธ๋ํ๊ฐ ๋ถ์ ๋ฉ๋๋ค.
- Backward hooks ์ด ์บก์ณ๋์ง ์์ต๋๋ค
Compiled Autograd๋ ์ด๋ฌํ ์ ํ์ฌํญ์ ํด๊ฒฐํ๊ธฐ ์ํด autograd ์์ง๊ณผ ์ง์ ํตํฉ๋๋ฉฐ, ์คํ ์ ์ ์ฒด backward ๊ทธ๋ํ๋ฅผ ์บก์ณํ ์ ์์ต๋๋ค ์ด๋ฌํ ๋ ๊ฐ์ง ํน์ฑ์ ๊ฐ์ง ๋ชจ๋ธ์ ์ปดํ์ผ Autograd๋ฅผ ์๋ํด๋ณด๋ฉด ์ข์ผ๋ฉฐ, ์ ์ฌ์ ์ผ๋ก ๋ ์ข์ ์ฑ๋ฅ์ ์ป์ ์ ์์ต๋๋ค.
ํ์ง๋ง, Compiled Autograd์๋ ์ ํ ์ฌํญ์ด ์กด์ฌํฉ๋๋ค.
- backward ์์ ์ ์บ์๋ฅผ ํ์ธํ๊ธฐ ์ํด ๋ฐํ์ ์ค๋ฒํค๋๊ฐ ์ถ๊ฐ๋ฉ๋๋ค.
- ๋ ํฐ ์บก์ณ๋ฅผ ๋๋ฌธ์ dynamo ์์ ์ฌ์ปดํ์ผ๊ณผ ๊ทธ๋ํ ๋๊น์ด ๋ฐ์ํ๊ธฐ ์ฝ์ต๋๋ค.
Note
Compiled Autograd๋ ํ์ฌ ํ๋ฐํ ๊ฐ๋ฐ ์ค์ด๋ฉฐ, ์์ง ๊ธฐ์กด PyTorch ๊ธฐ๋ฅ๊ณผ ์์ ํ ํธํ๋์ง ์์ต๋๋ค. ํน์ ๊ธฐ๋ฅ์ ์ต์ ์ํ๋ Compiled Autograd Landing Page ๋ฅผ ์ฐธ๊ณ ํ์ธ์
์ด ํํ ๋ฆฌ์ผ์์, ๊ฐ๋จํ ์ ๊ฒฝ๋ง ๋ชจ๋ธ์ ๊ธฐ๋ฐ์ผ๋ก ์์ ๋ฅผ ์งํํฉ๋๋ค. 10์ฐจ์์ ์ ๋ ฅ ๋ฒกํฐ๋ฅผ ๋ฐ์, ๋จ์ผ ์ ํ ๋ ์ด์ด๋ฅผ ํต๊ณผ์ํจ ํ, ๋ ๋ค๋ฅธ 10์ฐจ์ ์ถ๋ ฅ ๋ฒกํฐ๋ฅผ ์์ฑํฉ๋๋ค.
import torch
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x):
return self.linear(x)torch.compile API๋ฅผ ํธ์ถ ํ๊ธฐ ์ ์, torch._dynamo.config.compiled_autograd ์ True ๋ก ์ค์ ํด์ฃผ์ธ์
model = Model()
x = torch.randn(10)
torch._dynamo.config.compiled_autograd = True
@torch.compile
def train(model, x):
loss = model(x).sum()
loss.backward()
train(model, x)์์ ์๋ ์ฝ๋๋, Model ํด๋์ค ์ธ์คํด์ค๋ฅผ ๋ง๋ค๊ณ torch.randn(10) ์ ์ฌ์ฉํ์ฌ ๋ฌด์์ 10์ฐจ์์ ํ
์ x ๋ฅผ ๋ง๋ญ๋๋ค.
ํ๋ จ ๋ฃจํ ํจ์ train ์ ์ ์ํ๊ณ , ์คํ ์ต์ ํ๋ฅผ ์ํด @torch.compile๋ก ์ง์ ํฉ๋๋ค.
train(model, x) ๊ฐ ํธ์ถ๋ ๋:
- Python ์ธํฐํ๋ฆฌํฐ๊ฐ Dynamo๋ฅผ ํธ์ถํฉ๋๋ค. ํด๋น ํธ์ถ์
@torch.compile๋ก ์ง์ ๋์๊ธฐ ๋๋ฌธ์ ๋๋ค. - Dynamo๋ Python ๋ฐ์ดํธ์ฝ๋๋ฅผ ๊ฐ๋ก์ฑ ์คํ์ ์๋ฎฌ๋ ์ด์ ํ๊ณ , ์ฐ์ฐ์ ๊ทธ๋ํ๋ก ๊ธฐ๋กํฉ๋๋ค.
AOTDispatcher์ ํ ์ ๋นํ์ฑํํ๊ณ autograd ์์ง์ ํธ์ถํ์ฌmodel.linear.weight์model.linear.bias์ ๋ณํ๋๋ฅผ ๊ณ์ฐํ๋ฉฐ, ์ฐ์ฐ์ ๊ทธ๋ํ๋ก ๊ธฐ๋กํฉ๋๋ค.torch.autograd.Function์ ์ฌ์ฉํ์ฌ, AOTDispatcher๋trainํจ์์ forward์ backward ๊ตฌํ์ ์ฌ์์ฑํฉ๋๋ค.- Inductor๋ AOTDispatch forward์ backward ๊ตฌํ์ ์ต์ ํํ ํจ์๋ฅผ ์์ฑํฉ๋๋ค.
- Dynamo๋ Python ์ธํฐํ๋ฆฌํฐ๊ฐ ๋ค์์ ํ๊ฐํ ์ต์ ํ๋ ํจ์๋ฅผ ์ค์ ํฉ๋๋ค.
- Python ์ธํฐํ๋ฆฌํฐ๋ ์ต์ ํ๋ ํจ์๋ฅผ ์คํํ๊ณ ,
loss = model(x).sum()์ ์คํํฉ๋๋ค. - Python ์ธํฐํ๋ฆฌํฐ๋
loss.backward()๋ฅผ ์คํํ๋ฉฐ ๋ด๋ถ autograd ์์ง์ ํธ์ถํ๊ณ ,torch._dynamo.config.compiled_autograd = True๋ก ์ค์ ํ๊ธฐ ๋๋ฌธ์, ํด๋น ํธ์ถ์ Compiled Autograd ์์ง์ผ๋ก ์ ๋ฌ๋ฉ๋๋ค. - Compiled Autograd ๋
model.linear.weight์model.linear.bias๋ณํ๋๋ฅผ ๊ณ์ฐํ๊ณ , ๋ง๋๋ ํ ์ ํฌํจํ์ฌ ์ฐ์ฐ์ ๊ทธ๋ํ๋ก ๊ธฐ๋กํฉ๋๋ค. ์ด ๊ณผ์ ์์ AOTDispatcher๊ฐ ์ด์ ์ ์ฌ์์ฑํ backward๋ ๊ธฐ๋ก๋ฉ๋๋ค. ๊ทธ ๋ค์ Compiled Autograd๋loss.backward()์ ์์ ํ ์ถ์ ๋ ๊ตฌํ์ ํด๋นํ๋ ์ ํจ์๋ฅผ ์์ฑํ๊ณ , ์ด๋ฅผ ์ถ๋ก ๋ชจ๋์์torch.compile๋ก ์คํํฉ๋๋ค. - ๋์ผํ ๋จ๊ณ์์ ์ฌ๊ท์ ์ผ๋ก Compiled Autograd ๊ทธ๋ํ์ ์ ์ฉ๋์ง๋ง, ์ด๋ฒ์๋ AOTDispatcher๊ฐ ๊ทธ๋ํ๋ฅผ ๋ถํ ํ ํ์๊ฐ ์์ต๋๋ค.
TORCH_LOGS ํ๊ฒฝ ๋ณ์๋ฅผ ์ค์ ํ์ฌ ์คํฌ๋ฆฝํธ๋ฅผ ์คํํฉ๋๋ค:
- Compiled Autograd ๊ทธ๋ํ๋ฅผ ์ถ๋ ฅํ๋ ค๋ฉด
TORCH_LOGS="compiled_autograd" python example.py์ ์ฌ์ฉํ์ธ์ - ๋ ๋ง์ ํ
์ ๋ฉํ๋ฐ์ดํฐ์ ์ฌ์ปดํ์ผ ์ด์ ๊น์ง ์ถ๋ ฅํ๊ณ ์ถ๋ค๋ฉด, ์ฑ๋ฅ์ด ์ ํ๋๋ ๋์
TORCH_LOGS="compiled_autograd_verbose" python example.py๋ฅผ ์ฌ์ฉํ์ธ์
์์ ์คํผ๋ท์ ๋ค์ ์คํํ๋ฉด, compiled autograd ๊ทธ๋ํ๊ฐ stderr ์ ๋ก๊น
์ด ๋ฉ๋๋ค.
์ผ๋ถ ๊ทธ๋ํ ๋
ธ๋๋ aot0_ ์ ๋์ฌ๊ฐ ๋ถ์ ์ด๋ฆ์ ๊ฐ์ง๋ฉฐ, ์ด๋ ์ด์ ์ AOTAutograd backward ๊ทธ๋ํ 0์์ ์ฌ์ ์ปดํ์ผ๋ ๋
ธ๋์ ํด๋นํฉ๋๋ค, ์๋ฅผ ๋ค์ด aot0_view_2 ๋ id=0์ธ AOT backward ๊ทธ๋ํ์ view_2 ์ ๋์๋ฉ๋๋ค.
์๋์ ์ด๋ฏธ์ง์์, ๋นจ๊ฐ ๋ฐ์ค๋ Compiled Autograd ์์ด torch.compile ๋ก ์บก์ณ๋ AOT backward ๊ทธ๋ํ๋ฅผ ๊ฐ์ธ๊ณ ์์ต๋๋ค.
Note
์ด ๊ทธ๋ํ๋ ์ฐ๋ฆฌ๊ฐ torch.compile ์ ํธ์ถํ ๋์์ด๋ฉฐ, ์ต์ ํ๋ ๊ทธ๋ํ๊ฐ ์๋๋๋ค. Compiled Autograd๋ ๊ธฐ๋ณธ์ ์ผ๋ก C++ autograd ์คํ์ ๋ํ๋ด๊ธฐ ์ํด ์ผ๋ถ ์ต์ ํ๋์ง ์์ Python ์ฝ๋๋ฅผ ์์ฑํฉ๋๋ค.
๋ ๊ฐ์ง์ ์ปดํ์ผ์ ๋ํด ์๋ก ๋ค๋ฅธ ์ปดํ์ผ๋ฌ ์ค์ ์ ์ฌ์ฉํ ์ ์์ต๋๋ค, ์๋ฅผ ๋ค์ด forward์ ๊ทธ๋ํ ๋ถ์ ์ด ์๋๋ผ๋ backward๋ fullgraph๋ก ์ค์ ํ ์ ์์ต๋๋ค.
def train(model, x):
model = torch.compile(model)
loss = model(x).sum()
torch._dynamo.config.compiled_autograd = True
torch.compile(lambda: loss.backward(), fullgraph=True)()๋๋ context manager๋ฅผ ์ฌ์ฉํ ์ ์์ผ๋ฉฐ, ํด๋น ์ค์ฝํ ์์ ๋ชจ๋ autograd ํธ์ถ์ ์ ์ฉ๋ ๊ฒ์ด๋ค.
def train(model, x):
model = torch.compile(model)
loss = model(x).sum()
with torch._dynamo.compiled_autograd.enable(torch.compile(fullgraph=True)):
loss.backward()- forward ํจ์ค์ ๊ทธ๋ํ ๋ถ์ ์ ๋ ์ด์ backward ํจ์ค์ ๊ทธ๋ํ ๋ถ์ ๋ก ์ด์ด์ง์ง ์์ต๋๋ค.
@torch.compile(backend="aot_eager")
def fn(x):
# 1st graph
temp = x + 10
torch._dynamo.graph_break()
# 2nd graph
temp = temp + 10
torch._dynamo.graph_break()
# 3rd graph
return temp.sum()
x = torch.randn(10, 10, requires_grad=True)
torch._dynamo.utils.counters.clear()
loss = fn(x)
# 1. base torch.compile
loss.backward(retain_graph=True)
assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 3)
torch._dynamo.utils.counters.clear()
# 2. torch.compile with compiled autograd
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
loss.backward()
# single graph for the backward
assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 1)์ฒซ ๋ฒ์งธ torch.compile ์ ๊ฒฝ์ฐ์๋, ์ปดํ์ผ๋ ํจ์ fn ์์ 2๊ฐ์ ๊ทธ๋ํ ๋ถ์ ๋ก ์ธํด 3๊ฐ์ backward ๊ทธ๋ํ๊ฐ ์์ฑ๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
๋ฐ๋ฉด, Compiled Autograd๋ฅผ ์ฌ์ฉํ ๋ ๋ฒ์งธ torch.compile ๊ฒฝ์ฐ์๋ ๊ทธ๋ํ ๋ถ์ ์ด ์๋๋ผ๋ ์ ์ฒด backward ๊ทธ๋ํ๊ฐ ํธ๋ ์ด์ค๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
Note
Compiled Autograd๊ฐ ์บก์ณํ backward ํ ์ ํธ๋ ์ด์คํ ๋, Dynamo์์ ๊ทธ๋ํ๊ฐ ๋ถ์ ๋ ๊ฐ๋ฅ์ฑ์ ์ฌ์ ํ ์กด์ฌํฉ๋๋ค.
- Backward ํ ์ ์บก์ณ๋ ์ ์์ต๋๋ค.
@torch.compile(backend="aot_eager")
def fn(x):
return x.sum()
x = torch.randn(10, 10, requires_grad=True)
x.register_hook(lambda grad: grad+10)
loss = fn(x)
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
loss.backward()๊ทธ๋ํ์๋ call_hook ๋
ธ๋๊ฐ ์์ด์ผ ํ๋ฉฐ, ์ดํ dynamo๋ ์ด๋ฅผ ๋ค์๊ณผ ๊ฐ์ด ์ธ๋ผ์ธ ์ฒ๋ฆฌํฉ๋๋ค.
- ์์ค ๊ฐ์ autograd ๊ตฌ์กฐ๊ฐ ๋ณ๊ฒฝ๋์๊ธฐ ๋๋ฌธ์ ๋๋ค.
torch._dynamo.config.compiled_autograd = True
x = torch.randn(10, requires_grad=True)
for op in [torch.add, torch.sub, torch.mul, torch.div]:
loss = op(x, x).sum()
torch.compile(lambda: loss.backward(), backend="eager")()์์ ์์ ์์, ๊ฐ ๋ฐ๋ณต๋ง๋ค ๋ค๋ฅธ ์ฐ์ฐ์ ํธ์ถํ์ฌ loss ๊ฐ ๋งค๋ฒ ๋ค๋ฅธ autograd ๊ธฐ๋ก์ ์ถ์ ํฉ๋๋ค. ์ด๋ก ์ธํด ์ฌ์ปดํ์ผ ๋ฉ์์ง๊ฐ (Cache miss due to new autograd node) ํ์๋๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
- ํ ์์ ํํ๊ฐ ๋ณ๊ฒฝ๋์๊ธฐ ๋๋ฌธ์ ๋๋ค.
torch._dynamo.config.compiled_autograd = True
for i in [10, 100, 10]:
x = torch.randn(i, i, requires_grad=True)
loss = x.sum()
torch.compile(lambda: loss.backward(), backend="eager")()์์ ์์ ์์, x ์ ํํ๊ฐ ๋ณ๊ฒฝ๋๋ฉด, compiled autograd๋ ์ฒซ ๋ฒ์งธ ๋ณ๊ฒฝ ์ดํ x ๋ฅผ ๋์ ํํ ํ
์๋ก ํ์ํฉ๋๋ค. ์ด๋ก ์ธํด ์ฌ์ปดํ์ผ ๋ฉ์์ง๊ฐ (Cache miss due to changed shapes) ๋ํ๋๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
์ด ํํ ๋ฆฌ์ผ์์๋, torch.compile ๊ณผ compiled autograd์ ๊ณ ์ฐจ์ ์ํ๊ณ, compiled autograd์ ๊ธฐ์ด์ ๋ช ๊ฐ์ง์ ๊ณตํต์ ์ธ ์ฌ์ปดํ์ผ ์ด์ ๋ฅผ ์ดํด๋ณด์์ต๋๋ค. ์์ธํ ๋ด์ฉ์ dev-discuss ์์ ํ์ธํ ์ ์์ต๋๋ค.



