본문 바로가기

잡것들

pytorch nan 디버깅 하는 법

torch.autograd.set_detect_anomaly(True)

이것을 설정하고 Nan이 발생하게 되면 RuntimeError가 발생되고 종료된다.

 

이 때 VSCode의 debugger를 사용해도 되고 

RuntimeError에 trace에 뜨는 부분을 직접 찾아가도 된다.

 

대표적으로 세가지 경우가 있다.

 

  • input에 nan이 있는 경우
  • forward시 network에 문제가 있는 경우
  • backpropagation 시 loss에 문제가 있는 경우

 

나의 경우 backpropagation에서 문제가 있었다. 

 

atan2의 경우 real, imag 값이 모두 0일때 backward를 수행하면 아래와 같이 gradient 가 nan이 나오게 된다.

x = torch.zeros(4, requires_grad=True)
y = torch.zeros(4)
out = torch.atan2(x, y)
out.mean().backward()
print(x.grad)
> tensor([nan, nan, nan, nan])

 

Loss 구현 시 

        targets_real = targets[:, 0]
        targets_imag = targets[:, 1]
        targets_phase = torch.atan2(targets_real, targets_imag)

위와 같이 구현해놓았었는데 이때 

targets_real 과 targets_imag 모두 0인 경우가 있을 수 있다. 음성 처리의 경우 무음 구간이 그러하다. 

		eps=1-e6
        targets_real = targets[:, 0] + eps
        targets_imag = targets[:, 1] + eps
        targets_phase = torch.atan2(targets_real, targets_imag)

이렇게 적절한 작은 값의 eps를 넣어주면 문제가 해결된다.

다만 target이 무음이 많은 경우 network는 45도의 phase로 추정하게끔 학습이 될 가능성이 있다.