
The calculation of w in SO3 to_quaternion() function may cause gradient explosion. #670

@HiOnes

🐛 Bug

I ran into exploding gradients during training. Running with autograd.detect_anomaly() gave the hint: Function 'SqrtBackward0' returned nan values in its 0th output. I don't call torch.sqrt() or similar functions anywhere in my own code, so I suspected the bug lies in the internal calculations of Theseus. Having noticed the related fix in #661, I checked the implementation of the to_quaternion() function in so3.py.

The calculation of sine_half_theta already uses an eps as the lower clamp bound. The relevant code is:

sqrt_eps = _THESEUS_GLOBAL_PARAMS.get_eps("so3", "to_quaternion_sqrt", w.dtype)
sine_half_theta = (
    (0.5 * (1 - cosine_near_pi)).clamp(sqrt_eps, 1).sqrt().view(-1, 1)
)

However, sqrt is also used in the calculation of w:

w = 0.5 * (1 + self[:, 0, 0] + self[:, 1, 1] + self[:, 2, 2]).clamp(0, 4).sqrt()

Here the argument of sqrt is only clamped to the range [0, 4]. When it is at or close to 0, the gradient of sqrt becomes infinite and the backward pass may fail.
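To make the failure mode explicit, here is a short derivation. The symbols $t$ (the clamped quantity) and $\theta$ (the rotation angle) are introduced only for illustration and do not appear in the Theseus code. With $t = 1 + R_{00} + R_{11} + R_{22}$:

$$
w = \tfrac{1}{2}\sqrt{t}, \qquad \frac{\partial w}{\partial t} = \frac{1}{4\sqrt{t}} \to \infty \quad \text{as } t \to 0^{+}
$$

For a rotation matrix, $\operatorname{tr}(R) = 1 + 2\cos\theta$, so $t = 2 + 2\cos\theta$, which reaches 0 exactly at $\theta = \pi$. The unbounded derivative therefore shows up for rotations by, or very near, 180 degrees.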

Steps to Reproduce

I prepared a simple test script to reproduce this bug:

import theseus as th
import torch
import torch.nn.functional as F

rot = torch.tensor([[1.0, 0.0, 0.0],
                    [0.0, -1.0, 0.0],
                    [0.0, 0.0, -1.0]], requires_grad=True).reshape(1, 3, 3)
rot_so3 = th.SO3(tensor=rot)
identity_quat = torch.tensor([1.0, 0.0, 0.0, 0.0]).reshape(1, 4)
err = F.mse_loss(rot_so3.to_quaternion(), identity_quat)
rot.retain_grad()
err.backward()
print(rot.grad)

The output is:

tensor([[[-inf, 0., 0.],
         [0., -inf, 0.],
         [0., 0., -inf]]])

If I add an eps (1e-6 in my test) as the lower clamp bound in the calculation of w:

w = 0.5 * (1 + self[:, 0, 0] + self[:, 1, 1] + self[:, 2, 2]).clamp(1e-6, 4).sqrt()

the gradient becomes:

tensor([[[-0.0625,  0.0000,  0.0000],
         [ 0.0000, -0.0625,  0.0000],
         [ 0.0000,  0.0000, -0.0625]]])

System Info

  • OS : Ubuntu 20.04
  • Python version: 3.8
  • CUDA version: 11.8
