-
Notifications
You must be signed in to change notification settings - Fork 1.2k
[New Feasture]: Add a FLOPs collection interface #1302
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -70,6 +70,11 @@ def forward(self, x, context=None, mask=None, pos_bias=None): | |
| k = self.k(context).view(b, -1, n, c) | ||
| v = self.v(context).view(b, -1, n, c) | ||
|
|
||
| # For caculate flops | ||
|
||
| self.q_shape = q.shape | ||
| self.k_shape = k.shape | ||
| self.v_shape = v.shape | ||
|
|
||
| # attention bias | ||
| attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) | ||
| if pos_bias is not None: | ||
|
|
@@ -327,4 +332,4 @@ def _clean(self, text): | |
| text = whitespace_clean(basic_clean(text)).lower() | ||
| elif self.clean == 'canonicalize': | ||
| text = canonicalize(basic_clean(text)) | ||
| return text | ||
| return text | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| from .flops_profiler import ( | ||
| profile_entire_model, | ||
| unprofile_entire_model, | ||
| get_flops, | ||
| print_model_profile, | ||
| ) |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,252 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch.nn as nn | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from functools import wraps | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import time | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from collections import defaultdict | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import flash_attn | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from einops import rearrange | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from torch.utils.flop_counter import conv_flop_count | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_dit_flops(model): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_dit_flops(dit_block_model): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| total_flops = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for sub_model in dit_block_model.modules(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| total_flops += getattr(sub_model, '__flops__', 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return total_flops | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| total_flops = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| total_duration = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for sub_module in model.modules(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if sub_module.__class__.__name__ == 'DiTBlock': | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| total_flops += get_dit_flops(sub_module) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| total_duration += getattr(sub_module, '__duration__', 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Tflops = total_flops / 1e12 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return Tflops | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_dit_flops(model): | |
| def get_dit_flops(dit_block_model): | |
| total_flops = 0 | |
| for sub_model in dit_block_model.modules(): | |
| total_flops += getattr(sub_model, '__flops__', 0) | |
| return total_flops | |
| total_flops = 0 | |
| total_duration = 0 | |
| for sub_module in model.modules(): | |
| if sub_module.__class__.__name__ == 'DiTBlock': | |
| total_flops += get_dit_flops(sub_module) | |
| total_duration += getattr(sub_module, '__duration__', 0) | |
| Tflops = total_flops / 1e12 | |
| return Tflops | |
| def get_dit_flops(model): | |
| def _get_dit_flops_recursive(dit_block_model): | |
| total_flops = 0 | |
| for sub_model in dit_block_model.modules(): | |
| total_flops += getattr(sub_model, '__flops__', 0) | |
| return total_flops | |
| total_flops = 0 | |
| for sub_module in model.modules(): | |
| if sub_module.__class__.__name__ == 'DiTBlock': | |
| total_flops += _get_dit_flops_recursive(sub_module) | |
| Tflops = total_flops / 1e12 | |
| return Tflops |
mahaocong90 marked this conversation as resolved.
Show resolved
Hide resolved
mahaocong90 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The signature of profiled_forward is (self, x, *args, **kwargs), which assumes that every wrapped module's forward method has a first positional argument x. This is not always true and can lead to a TypeError for modules with different forward signatures (e.g., no arguments, or keyword-only arguments). The signature should be (self, *args, **kwargs) to be generic and robust.
| def profiled_forward(self, x, *args, **kwargs): | |
| return module._original_forward(x, *args, **kwargs) | |
| def profiled_forward(self, *args, **kwargs): | |
| return module._original_forward(*args, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function is redefined on every iteration of the training loop, which is inefficient. It's better to define it once outside the loop. To do so, you'll need to pass the
timingdictionary as an argument.For example, you could define it before the loop:
And then call it inside the loop as
format_time(timing, "step").Since I cannot suggest changes outside of the current diff hunk, I'm leaving this as a comment for you to refactor.