-
Notifications
You must be signed in to change notification settings - Fork 285
Expand file tree
/
Copy pathper_sample_grads.py
More file actions
218 lines (167 loc) Β· 9.48 KB
/
per_sample_grads.py
File metadata and controls
218 lines (167 loc) Β· 9.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
# -*- coding: utf-8 -*-
"""
Per-sample-gradients
====================
**Translator:** `justjs4evr <https://github.com/justjs4evr>`__

What is it?
-----------
Per-sample-gradient computation is computing the gradient for each and every
sample in a batch of data. It is a useful quantity in differential privacy,
meta-learning, and optimization research.

.. note::
   This tutorial requires PyTorch 2.0.0 or later.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# Fix the RNG seed so the dummy data and the model initialization are reproducible.
torch.manual_seed(0)
# Here is a simple CNN model and a loss function:
class SimpleCNN(nn.Module):
    """A small MNIST-style CNN: two conv layers followed by two linear layers.

    Takes ``(N, 1, 28, 28)`` images and returns ``(N, 10)`` log-probabilities.
    """

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Run the forward pass and return per-class log-probabilities."""
        hidden = F.relu(self.conv1(x))
        hidden = F.relu(self.conv2(hidden))
        hidden = F.max_pool2d(hidden, 2)
        flat = torch.flatten(hidden, 1)
        flat = F.relu(self.fc1(flat))
        logits = self.fc2(flat)
        return F.log_softmax(logits, dim=1)
def loss_fn(predictions, targets):
    """Negative log-likelihood loss of log-probability ``predictions`` vs ``targets``."""
    nll = F.nll_loss(predictions, targets)
    return nll
######################################################################
# Let's pretend we are working with the MNIST dataset and generate a batch of
# dummy data: 64 images of size 28x28 with integer class labels in [0, 10).

# Fall back to the CPU when no GPU is available so the example still runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_models = 10
batch_size = 64
data = torch.randn(batch_size, 1, 28, 28, device=device)
# use ``batch_size`` here as well, instead of the previously hard-coded 64
targets = torch.randint(10, (batch_size,), device=device)
######################################################################
# In regular model training, one forwards the minibatch through the model and
# calls ``.backward()`` to compute the gradients.  This produces an 'average'
# gradient over the entire minibatch:
model = SimpleCNN().to(device=device)
predictions = model(data)  # move the whole minibatch through the model
loss = loss_fn(predictions, targets)
loss.backward()  # back-propagates the gradient averaged over the minibatch
######################################################################
# In contrast, per-sample-gradient computation is equivalent to:
#
# - for each individual sample of the data, perform a forward and a backward
#   pass to get an individual (per-sample) gradient.
def compute_grad(sample, target):
    """Gradient of the loss w.r.t. every model parameter for one (sample, target) pair."""
    batched_sample = sample.unsqueeze(0)  # the model expects a batch dimension
    batched_target = target.unsqueeze(0)
    prediction = model(batched_sample)
    sample_loss = loss_fn(prediction, batched_target)
    return torch.autograd.grad(sample_loss, list(model.parameters()))
def compute_sample_grads(data, targets):
    """Naively compute per-sample gradients: one forward/backward per sample."""
    per_sample = (compute_grad(data[i], targets[i]) for i in range(batch_size))
    # regroup from per-sample tuples of params to per-param stacks of samples
    grouped_by_param = zip(*per_sample)
    return [torch.stack(param_grads) for param_grads in grouped_by_param]
per_sample_grads = compute_sample_grads(data, targets)
######################################################################
# ``sample_grads[0]`` is the per-sample-gradient for ``model.conv1.weight``.
# ``model.conv1.weight.shape`` is ``[32, 1, 3, 3]``; note that there is one
# gradient per sample in the batch, for a total of 64.
print(per_sample_grads[0].shape)
######################################################################
# Per-sample-grads, the *efficient* way, using function transforms
# ----------------------------------------------------------------
# We can compute per-sample-gradients more efficiently by using function
# transforms.
#
# The ``torch.func`` function transform API transforms over functions.
# Our strategy is to define a function that computes the loss and then apply
# transforms to construct a function that computes per-sample-gradients.
#
# We will use the ``torch.func.functional_call`` function to treat an
# ``nn.Module`` like a function.
#
# First, let's extract the state from ``model`` into two dictionaries,
# parameters and buffers. We detach them because we will not be using regular
# PyTorch autograd (e.g. Tensor.backward(), torch.autograd.grad).
from torch.func import functional_call, vmap, grad
# Detached (name -> tensor) snapshots of the model state; detached because the
# function-transform API replaces regular autograd here.
params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}
######################################################################
# Next, define a function that computes the model's loss for a *single*,
# unbatched input.  It is important that this function accepts the
# parameters, the input and the target, because we will be transforming
# over all of them.
#
# Note - the model was written to operate on batches, so ``torch.unsqueeze``
# is used to add a batch dimension.
def compute_loss(params, buffers, sample, target):
    """Loss of the (stateless) model on a single, unbatched (sample, target) pair."""
    batched_input = sample.unsqueeze(0)
    batched_target = target.unsqueeze(0)
    output = functional_call(model, (params, buffers), (batched_input,))
    return loss_fn(output, batched_target)
######################################################################
# Now use the ``grad`` transform to create a new function that computes the
# gradient of ``compute_loss`` with respect to its first argument, ``params``.
ft_compute_grad = grad(compute_loss)
######################################################################
# ``ft_compute_grad`` computes the gradient for a single (sample, target)
# pair.  ``vmap`` turns it into a function that computes gradients over a
# whole batch of samples and targets.  ``in_dims=(None, None, 0, 0)`` maps
# ``ft_compute_grad`` over dimension 0 of the data and the targets, while
# reusing the same params and buffers for every sample.
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
######################################################################
# Finally, use the transformed function to compute the per-sample-gradients:
ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)
######################################################################
# Check that the ``grad``-and-``vmap`` results match the results of the
# hand-processed, one-sample-at-a-time computation:
for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()):
    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=1.2e-1, rtol=1e-5)
######################################################################
# A quick note: there are limitations around what types of functions can be
# transformed by ``vmap``. The best functions to transform are pure
# functions: ones whose outputs are determined only by the inputs and that
# have no side effects (e.g. mutation). ``vmap`` cannot handle mutation of
# arbitrary Python data structures, but it can handle many in-place PyTorch
# operations.
#
# Performance comparison
# ----------------------
#
# Curious about how the performance of ``vmap`` compares?
#
# Currently the best results are obtained on newer GPUs such as the A100
# (Ampere), where we have seen up to 25x speedups on this example. Here are
# some results from our build machines:
def get_perf(first, first_descriptor, second, second_descriptor):
    """Take two torch.utils.benchmark measurements and print the relative
    difference between the first and the second, as a percentage."""
    baseline = first.times[0]
    contender = second.times[0]
    # absolute relative difference, expressed in percent
    final_gain = abs((baseline - contender) / baseline) * 100
    print(f"Performance delta: {final_gain:.4f} percent improvement with {first_descriptor} ")
from torch.utils.benchmark import Timer
# Time the naive per-sample loop against the vmap-based version (100 runs each).
without_vmap = Timer(stmt="compute_sample_grads(data, targets)", globals=globals())
with_vmap = Timer(stmt="ft_compute_sample_grad(params, buffers, data, targets)",globals=globals())
no_vmap_timing = without_vmap.timeit(100)
with_vmap_timing = with_vmap.timeit(100)
print(f'Per-sample-grads without vmap {no_vmap_timing}')
print(f'Per-sample-grads with vmap {with_vmap_timing}')
get_perf(with_vmap_timing, "vmap", no_vmap_timing, "no vmap")
######################################################################
# There are other optimized solutions for computing per-sample-gradients in
# PyTorch that perform even better (e.g. https://github.com/pytorch/opacus),
# but it is neat that composing ``vmap`` and ``grad`` alone gives us this
# kind of speedup.
#
# In general, vectorization with ``vmap`` should be faster than running a
# function in a for-loop, and competitive with manual batching. There are
# some exceptions though, such as when a ``vmap`` rule has not been
# implemented for a particular operation, or when the underlying kernels are
# not optimized for older hardware (GPUs). If you encounter any of these
# cases, please let us know by opening an issue on GitHub.