-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Expand file tree
/
Copy pathtrain.py
More file actions
executable file
·366 lines (308 loc) · 12.5 KB
/
train.py
File metadata and controls
executable file
·366 lines (308 loc) · 12.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
#Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import sys
import logging
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import profiler
import reader
from utils import *
import models
from build_model import create_model
# Module-wide logger for this script; the root handler is configured at
# INFO level so logger.info(...) messages (e.g. the EMA validation notices)
# are emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class TimeAverager(object):
    """Keep a running mean of recorded durations.

    Used by ``train`` to average per-batch cost and reader cost between
    print intervals.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard every recorded sample."""
        self._cnt, self._total_time = 0, 0

    def record(self, usetime):
        """Add one duration sample ``usetime`` to the running total."""
        self._total_time += usetime
        self._cnt += 1

    def get_average(self):
        """Return the mean of the recorded samples, or 0 if there are none."""
        return self._total_time / self._cnt if self._cnt else 0
def build_program(is_train, main_prog, startup_prog, args):
    """Build the model into ``main_prog``; in train mode also add optimizer ops.

    Parameters:
        is_train: True builds the training program (optimizer, learning-rate
            fetch var, optional AMP / BF16 decoration and optional EMA);
            False builds the evaluation program.
        main_prog: main program the model ops are added to.
        startup_prog: startup program (paired with ``main_prog`` in
            ``fluid.program_guard``).
        args: parsed command-line arguments.

    Returns:
        A tuple ``(loss_out, optimizer)``.
        ``loss_out`` is the list returned by ``create_model``; its first
        entry is the average cost used for ``minimize``. Extra entries are
        appended in this order:
            train mode: [..., global_lr, ema (only if args.use_ema), data_loader]
            test mode:  [..., data_loader]
        ``optimizer`` is None in test mode; in train mode it is the
        optimizer returned by ``create_optimizer``, wrapped by the AMP /
        BF16 decorator when ``args.use_amp`` / ``args.use_amp_bf16`` is set.
    """
    if args.model.startswith('EfficientNet'):
        # EfficientNet variants take extra construction arguments.
        override_params = {"drop_connect_rate": args.drop_connect_rate}
        padding_type = args.padding_type
        use_se = args.use_se
        model = models.__dict__[args.model](is_test=not is_train,
                                            override_params=override_params,
                                            padding_type=padding_type,
                                            use_se=use_se)
    else:
        model = models.__dict__[args.model]()
    optimizer = None
    with fluid.program_guard(main_prog, startup_prog):
        # NOTE(review): when only enable_ce is set, this assigns whatever
        # (possibly falsy) value args.random_seed holds -- confirm intended.
        if args.random_seed or args.enable_ce:
            main_prog.random_seed = args.random_seed
            startup_prog.random_seed = args.random_seed
        with fluid.unique_name.guard():
            data_loader, loss_out = create_model(model, args, is_train)
            # add backward op in program
            if is_train:
                optimizer = create_optimizer(args)
                avg_cost = loss_out[0]
                # XXX: fetch the learning rate here for now; a better
                # implementation is needed.
                global_lr = optimizer._global_learning_rate()
                # Persistable so it can be fetched after each step.
                global_lr.persistable = True
                loss_out.append(global_lr)
                # Optionally wrap the optimizer for fp16 AMP or bf16 mixed
                # precision before minimize() builds the backward pass.
                if args.use_amp:
                    optimizer = paddle.static.amp.decorate(
                        optimizer,
                        init_loss_scaling=args.scale_loss,
                        use_dynamic_loss_scaling=args.use_dynamic_loss_scaling,
                        use_pure_fp16=args.use_pure_fp16,
                        use_fp16_guard=True)
                elif args.use_amp_bf16:
                    # conv2d is added to the custom bf16 op list.
                    optimizer = paddle.static.amp.bf16.decorate_bf16(
                        optimizer,
                        amp_lists=paddle.static.amp.bf16.
                        AutoMixedPrecisionListsBF16(
                            custom_bf16_list={"conv2d"}),
                        use_bf16_guard=None,
                        use_pure_bf16=args.use_pure_bf16)
                optimizer.minimize(avg_cost)
                if args.use_ema:
                    # Step counter passed to EMA as thres_steps.
                    global_steps = fluid.layers.learning_rate_scheduler._decay_step_counter(
                    )
                    ema = ExponentialMovingAverage(
                        args.ema_decay, thres_steps=global_steps)
                    ema.update()
                    loss_out.append(ema)
            # The data loader is always the last element of loss_out.
            loss_out.append(data_loader)
    return loss_out, optimizer
def validate(args,
             test_iter,
             exe,
             test_prog,
             test_fetch_list,
             pass_id,
             train_batch_metrics_record,
             train_batch_time_record=None,
             train_prog=None):
    """Run one validation pass and print per-batch and per-epoch metrics.

    Args:
        args: parsed arguments (reads print_step, class_dim, enable_ce and
            use_gpu).
        test_iter: iterable of feed batches passed to ``exe.run``.
        exe: the fluid Executor.
        test_prog: the evaluation program.
        test_fetch_list: names of the variables fetched per batch; the first
            entry is also handed to ``best_strategy_compiled``.
        pass_id: current epoch index, used for printing.
        train_batch_metrics_record: per-batch averaged training metrics of
            this epoch; averaged over batches and printed next to the test
            metrics.
        train_batch_time_record: per-batch training times, only used for the
            "ce" report when ``args.enable_ce`` is set.
        train_prog: forwarded as ``share_prog`` to ``best_strategy_compiled``
            in single-trainer mode.
    """
    # Multi-trainer jobs run the raw test program; a single trainer compiles
    # it via best_strategy_compiled, sharing with train_prog.
    distributed = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) > 1
    if distributed:
        compiled_program = test_prog
    else:
        compiled_program = best_strategy_compiled(
            args, test_prog, test_fetch_list[0], exe,
            mode="val", share_prog=train_prog)

    elapsed_per_batch = []
    metrics_per_batch = []
    for batch_idx, batch in enumerate(test_iter):
        tic = time.time()
        fetched = exe.run(program=compiled_program,
                          feed=batch,
                          fetch_list=test_fetch_list)
        elapsed = time.time() - tic
        elapsed_per_batch.append(elapsed)

        # Average over axis 1 of the stacked fetch results (presumably the
        # per-sample axis -- confirm against create_model's outputs).
        batch_avg = np.mean(np.array(fetched), axis=1)
        metrics_per_batch.append(batch_avg)

        print_info("batch", batch_avg, elapsed, pass_id, batch_idx,
                   args.print_step, args.class_dim)
        sys.stdout.flush()

    train_avg = np.mean(np.array(train_batch_metrics_record), axis=0)
    time_avg = np.mean(np.array(elapsed_per_batch))
    test_avg = np.mean(np.array(metrics_per_batch), axis=0)

    print_info("epoch",
               list(train_avg) + list(test_avg),
               time_avg,
               pass_id=pass_id,
               class_dim=args.class_dim)

    if not args.enable_ce:
        return
    device_num = fluid.core.get_cuda_device_count() if args.use_gpu else 1
    print_info("ce",
               list(train_avg) + list(test_avg),
               train_batch_time_record,
               device_num=device_num)
def train(args):
    """Train the model described by ``args``.

    Builds the train (and, if ``args.validate``, test) programs, initializes
    parameters, sets up DALI or ImageNetReader input, then runs
    ``args.num_epochs`` epochs. Each epoch optionally validates (first with
    EMA weights when ``args.use_ema``) and periodically saves the model on
    trainer 0.

    Args:
        args: all arguments.
    """
    startup_prog = fluid.Program()
    train_prog = fluid.Program()
    train_out, optimizer = build_program(
        is_train=True,
        main_prog=train_prog,
        startup_prog=startup_prog,
        args=args)
    # build_program always appends the data loader last; with EMA the EMA
    # object sits right before it.
    train_data_loader = train_out[-1]
    if args.use_ema:
        train_fetch_vars = train_out[:-2]
        ema = train_out[-2]
    else:
        train_fetch_vars = train_out[:-1]
    train_fetch_list = [var.name for var in train_fetch_vars]

    if args.validate:
        test_prog = fluid.Program()
        test_out, _ = build_program(
            is_train=False,
            main_prog=test_prog,
            startup_prog=startup_prog,
            args=args)
        test_data_loader = test_out[-1]
        test_fetch_vars = test_out[:-1]
        test_fetch_list = [var.name for var in test_fetch_vars]
        # Create test_prog and set layers' is_test params to True
        test_prog = test_prog.clone(for_test=True)

    gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
    place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup_prog)
    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))

    # Init model from a checkpoint or pretrained model.
    init_model(exe, args, train_prog)
    if args.use_amp or args.use_amp_bf16:
        optimizer.amp_init(
            place,
            scope=paddle.static.global_scope(),
            test_program=test_prog if args.validate else None)

    num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
    if args.use_dali:
        import dali
        train_iter = dali.train(settings=args)
        if trainer_id == 0:
            test_iter = dali.val(settings=args)
    else:
        imagenet_reader = reader.ImageNetReader(0 if num_trainers > 1 else None)
        train_reader = imagenet_reader.train(settings=args)
        # Distributed jobs feed only this process's device; a single trainer
        # feeds every visible device of the selected kind.
        if num_trainers > 1:
            places = place
        elif args.use_gpu:
            places = fluid.framework.cuda_places()
        else:
            places = fluid.framework.cpu_places()
        train_data_loader.set_sample_list_generator(train_reader, places)
        if args.validate:
            test_reader = imagenet_reader.val(settings=args)
            test_data_loader.set_sample_list_generator(test_reader, places)

    compiled_train_prog = best_strategy_compiled(args, train_prog,
                                                 train_fetch_vars[0], exe)
    # NOTE: this is for benchmark
    total_batch_num = 0
    batch_cost_averager = TimeAverager()
    reader_cost_averager = TimeAverager()
    for pass_id in range(args.num_epochs):
        if num_trainers > 1 and not args.use_dali:
            imagenet_reader.set_shuffle_seed(pass_id + (
                args.random_seed if args.random_seed else 0))
        train_batch_id = 0
        train_batch_time_record = []
        train_batch_metrics_record = []
        if not args.use_dali:
            train_iter = train_data_loader()
            if args.validate:
                test_iter = test_data_loader()

        batch_start = time.time()
        for batch in train_iter:
            # NOTE: this is for benchmark
            if args.max_iter and total_batch_num == args.max_iter:
                return
            reader_cost_averager.record(time.time() - batch_start)
            train_batch_metrics = exe.run(compiled_train_prog,
                                          feed=batch,
                                          fetch_list=train_fetch_list)
            train_batch_metrics_avg = np.mean(
                np.array(train_batch_metrics), axis=1)
            train_batch_metrics_record.append(train_batch_metrics_avg)
            # Record the time for ce and benchmark
            train_batch_elapse = time.time() - batch_start
            train_batch_time_record.append(train_batch_elapse)
            batch_cost_averager.record(train_batch_elapse)
            if trainer_id == 0:
                ips = float(args.batch_size) / batch_cost_averager.get_average()
                print_info(
                    "batch",
                    train_batch_metrics_avg,
                    batch_cost_averager.get_average(),
                    pass_id,
                    train_batch_id,
                    args.print_step,
                    reader_cost=reader_cost_averager.get_average(),
                    ips=ips)
                sys.stdout.flush()
                # Averages cover the window since the last print step.
                if train_batch_id % args.print_step == 0:
                    batch_cost_averager.reset()
                    reader_cost_averager.reset()
            train_batch_id += 1
            total_batch_num += 1
            batch_start = time.time()
            # NOTE: this is for the benchmark profiler
            if (args.is_profiler and pass_id == 0 and
                    train_batch_id == args.print_step):
                profiler.start_profiler("All")
            elif (args.is_profiler and pass_id == 0 and
                  train_batch_id == args.print_step + 5):
                profiler.stop_profiler("total", args.profiler_path)
                return

        if args.use_dali:
            train_iter.reset()

        if trainer_id == 0 and args.validate:
            if args.use_ema:
                logger.info('ExponentialMovingAverage validate start...')
                with ema.apply(exe):
                    # Pass by keyword: positionally, compiled_train_prog
                    # would land in train_batch_time_record.
                    validate(args, test_iter, exe, test_prog, test_fetch_list,
                             pass_id, train_batch_metrics_record,
                             train_batch_time_record=train_batch_time_record,
                             train_prog=compiled_train_prog)
                logger.info('ExponentialMovingAverage validate over!')
                if args.use_dali:
                    # The EMA pass exhausted the DALI iterator; rewind it so
                    # the regular validation below sees the data again.
                    test_iter.reset()
            validate(args, test_iter, exe, test_prog, test_fetch_list, pass_id,
                     train_batch_metrics_record, train_batch_time_record,
                     compiled_train_prog)
            if args.use_dali:
                test_iter.reset()

        if trainer_id == 0 and pass_id % args.save_step == 0:
            save_model(args, exe, train_prog, pass_id)
def main():
    """Parse arguments, print them on trainer 0, check them, then train."""
    args = parse_args()
    is_first_trainer = int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0
    if is_first_trainer:
        print_arguments(args)
    check_args(args)
    train(args)
if __name__ == '__main__':
    # This script builds fluid static-graph programs, so static mode must be
    # enabled before anything is constructed. `paddle` is already imported
    # at module level, so no re-import is needed here.
    paddle.enable_static()
    main()