
Compiled Autograd: Capturing a larger backward graph for ``torch.compile``
============================================================================

**Author:** Simon Fan

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * How Compiled Autograd interacts with ``torch.compile``
       * How to use the Compiled Autograd API
       * How to inspect logs using ``TORCH_LOGS``

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch 2.4
       * Complete the `Introduction to torch.compile <https://tutorials.pytorch.kr/intermediate/torch_compile_tutorial.html>`_ tutorial
       * Read through the TorchDynamo and AOTAutograd sections of `Get Started with PyTorch 2.x <https://pytorch.org/get-started/pytorch-2.0/>`_

Overview
--------

Compiled Autograd is a ``torch.compile`` extension introduced in PyTorch 2.4 that allows the capture of a larger backward graph.

While ``torch.compile`` does capture the backward graph, it does so only partially. The AOTAutograd component captures the backward graph ahead of time, with certain limitations:

* Graph breaks in the forward pass lead to graph breaks in the backward pass.
* Backward hooks are not captured.

Compiled Autograd addresses these limitations by integrating directly with the autograd engine, allowing it to capture the full backward graph at runtime. Models with these two characteristics should try Compiled Autograd, and may see better performance.

However, Compiled Autograd introduces its own limitations:

* Added runtime overhead at the start of the backward pass, for cache checks.
* It is more prone to recompiles and graph breaks in Dynamo, due to the larger capture.

.. note::

   Compiled Autograd is under active development and is not yet compatible with all existing PyTorch features. For the latest status of a particular feature, refer to the Compiled Autograd Landing Page.

Setup
-----

In this tutorial, we base our examples on this simple neural network model, which takes a 10-dimensional input vector, processes it through a single linear layer, and outputs another 10-dimensional vector.

.. code:: python

   import torch

   class Model(torch.nn.Module):
      def __init__(self):
         super().__init__()
         self.linear = torch.nn.Linear(10, 10)

      def forward(self, x):
         return self.linear(x)

Basic usage
-----------

Before calling the ``torch.compile`` API, set ``torch._dynamo.config.compiled_autograd`` to ``True``:

.. code:: python

   model = Model()
   x = torch.randn(10)

   torch._dynamo.config.compiled_autograd = True
   @torch.compile
   def train(model, x):
      loss = model(x).sum()
      loss.backward()

   train(model, x)

In the code above, we create an instance of the ``Model`` class and generate a random 10-dimensional tensor ``x`` using ``torch.randn(10)``. We define the training loop function ``train`` and decorate it with ``@torch.compile`` to optimize its execution. When ``train(model, x)`` is called:

* The Python interpreter calls Dynamo, since this call was decorated with ``@torch.compile``.
* Dynamo intercepts the Python bytecode, simulates its execution, and records the operations into a graph.
* ``AOTDispatcher`` disables hooks and calls the autograd engine to compute gradients for ``model.linear.weight`` and ``model.linear.bias``, and records the operations into a graph. Using ``torch.autograd.Function``, AOTDispatcher rewrites the forward and backward implementation of ``train``.
* Inductor generates a function corresponding to an optimized implementation of the AOTDispatcher forward and backward.
* Dynamo sets the optimized function to be evaluated next by the Python interpreter.
* The Python interpreter executes the optimized function, which executes ``loss = model(x).sum()``.
* The Python interpreter executes ``loss.backward()``, calling into the autograd engine, which routes to the Compiled Autograd engine since we set ``torch._dynamo.config.compiled_autograd = True``.
* Compiled Autograd computes the gradients for ``model.linear.weight`` and ``model.linear.bias``, and records the operations into a graph, including any hooks it encounters. During this process, it records the backward previously rewritten by AOTDispatcher. Compiled Autograd then generates a new function corresponding to a fully traced implementation of ``loss.backward()``, and executes it with ``torch.compile`` in inference mode.
* The same steps apply recursively to the Compiled Autograd graph, but this time AOTDispatcher will not need to partition the graph.

Inspecting the Compiled Autograd logs
--------------------------------------

Run the script with the ``TORCH_LOGS`` environment variable:

* To print only the Compiled Autograd graph, use ``TORCH_LOGS="compiled_autograd" python example.py``
* To print the graph with more tensor metadata and recompile reasons, at the cost of performance, use ``TORCH_LOGS="compiled_autograd_verbose" python example.py``

์œ„์˜ ์Šคํ”ผ๋„ท์„ ๋‹ค์‹œ ์‹คํ–‰ํ•˜๋ฉด, compiled autograd ๊ทธ๋ž˜ํ”„๊ฐ€ stderr ์— ๋กœ๊น…์ด ๋ฉ๋‹ˆ๋‹ค. ์ผ๋ถ€ ๊ทธ๋ž˜ํ”„ ๋…ธ๋“œ๋Š” aot0_ ์ ‘๋‘์‚ฌ๊ฐ€ ๋ถ™์€ ์ด๋ฆ„์„ ๊ฐ€์ง€๋ฉฐ, ์ด๋Š” ์ด์ „์— AOTAutograd backward ๊ทธ๋ž˜ํ”„ 0์—์„œ ์‚ฌ์ „ ์ปดํŒŒ์ผ๋œ ๋…ธ๋“œ์— ํ•ด๋‹นํ•ฉ๋‹ˆ๋‹ค, ์˜ˆ๋ฅผ ๋“ค์–ด aot0_view_2 ๋Š” id=0์ธ AOT backward ๊ทธ๋ž˜ํ”„์˜ view_2 ์— ๋Œ€์‘๋ฉ๋‹ˆ๋‹ค. ์•„๋ž˜์˜ ์ด๋ฏธ์ง€์—์„œ, ๋นจ๊ฐ„ ๋ฐ•์Šค๋Š” Compiled Autograd ์—†์ด torch.compile ๋กœ ์บก์ณ๋œ AOT backward ๊ทธ๋ž˜ํ”„๋ฅผ ๊ฐ์‹ธ๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.

.. image:: ../_static/img/compiled_autograd/entire_verbose_log.png

.. note::

   This is the graph that we call ``torch.compile`` on, NOT the optimized graph. Compiled Autograd essentially generates some unoptimized Python code to represent the entire C++ autograd execution.

Compiling the forward and backward passes with different flags
---------------------------------------------------------------

You can use different compiler configurations for the two compilations; for example, the backward may be compiled as a full graph even if the forward contains graph breaks:

.. code:: python

   def train(model, x):
       model = torch.compile(model)
       loss = model(x).sum()
       torch._dynamo.config.compiled_autograd = True
       torch.compile(lambda: loss.backward(), fullgraph=True)()

Or you can use the context manager, which applies to all autograd calls within its scope:

.. code:: python

   def train(model, x):
      model = torch.compile(model)
      loss = model(x).sum()
      with torch._dynamo.compiled_autograd.enable(torch.compile(fullgraph=True)):
         loss.backward()

Compiled Autograd addresses certain limitations of AOTAutograd
---------------------------------------------------------------

1. Graph breaks in the forward pass no longer necessarily lead to graph breaks in the backward pass:

.. code:: python

   @torch.compile(backend="aot_eager")
   def fn(x):
      # 1st graph
      temp = x + 10
      torch._dynamo.graph_break()
      # 2nd graph
      temp = temp + 10
      torch._dynamo.graph_break()
      # 3rd graph
      return temp.sum()

   x = torch.randn(10, 10, requires_grad=True)
   torch._dynamo.utils.counters.clear()
   loss = fn(x)

   # 1. base torch.compile
   loss.backward(retain_graph=True)
   assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 3)
   torch._dynamo.utils.counters.clear()

   # 2. torch.compile with compiled autograd
   with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
      loss.backward()

   # single graph for the backward
   assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 1)

In the first ``torch.compile`` case, we see that 3 backward graphs were produced due to the 2 graph breaks in the compiled function ``fn``. In the second case, with Compiled Autograd, we see that a full backward graph was traced despite the graph breaks.

.. note::

   It is still possible for Dynamo to graph break when tracing the backward hooks captured by Compiled Autograd.

2. Backward hooks can now be captured:

.. code:: python

   @torch.compile(backend="aot_eager")
   def fn(x):
      return x.sum()

   x = torch.randn(10, 10, requires_grad=True)
   x.register_hook(lambda grad: grad+10)
   loss = fn(x)

   with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
      loss.backward()

There should be a ``call_hook`` node in the graph, which Dynamo will later inline:

.. image:: ../_static/img/compiled_autograd/call_hook_node.png

Common recompilation reasons for Compiled Autograd
---------------------------------------------------

1. Due to a change in the autograd structure of the loss value:

.. code:: python

   torch._dynamo.config.compiled_autograd = True
   x = torch.randn(10, requires_grad=True)
   for op in [torch.add, torch.sub, torch.mul, torch.div]:
      loss = op(x, x).sum()
      torch.compile(lambda: loss.backward(), backend="eager")()

์œ„์˜ ์˜ˆ์ œ์—์„œ, ๊ฐ ๋ฐ˜๋ณต๋งˆ๋‹ค ๋‹ค๋ฅธ ์—ฐ์‚ฐ์„ ํ˜ธ์ถœํ•˜์—ฌ loss ๊ฐ€ ๋งค๋ฒˆ ๋‹ค๋ฅธ autograd ๊ธฐ๋ก์„ ์ถ”์ ํ•ฉ๋‹ˆ๋‹ค. ์ด๋กœ ์ธํ•ด ์žฌ์ปดํŒŒ์ผ ๋ฉ”์‹œ์ง€๊ฐ€ (Cache miss due to new autograd node) ํ‘œ์‹œ๋˜๋Š” ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

../_static/img/compiled_autograd/recompile_due_to_node.png

2. Due to tensors changing shape:

.. code:: python

   torch._dynamo.config.compiled_autograd = True
   for i in [10, 100, 10]:
      x = torch.randn(i, i, requires_grad=True)
      loss = x.sum()
      torch.compile(lambda: loss.backward(), backend="eager")()

์œ„์˜ ์˜ˆ์ œ์—์„œ, x ์˜ ํ˜•ํƒœ๊ฐ€ ๋ณ€๊ฒฝ๋˜๋ฉด, compiled autograd๋Š” ์ฒซ ๋ฒˆ์งธ ๋ณ€๊ฒฝ ์ดํ›„ x ๋ฅผ ๋™์  ํ˜•ํƒœ ํ…์„œ๋กœ ํ‘œ์‹œํ•ฉ๋‹ˆ๋‹ค. ์ด๋กœ ์ธํ•ด ์žฌ์ปดํŒŒ์ผ ๋ฉ”์‹œ์ง€๊ฐ€ (Cache miss due to changed shapes) ๋‚˜ํƒ€๋‚˜๋Š” ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

../_static/img/compiled_autograd/recompile_due_to_dynamic.png

Conclusion
----------

In this tutorial, we went over the high-level ecosystem of ``torch.compile`` with Compiled Autograd, the basics of Compiled Autograd, and a few common recompilation reasons. You can find more details on dev-discuss.