diff --git a/model/networks.py b/model/networks.py index a40c527..d6e06d9 100644 --- a/model/networks.py +++ b/model/networks.py @@ -387,7 +387,7 @@ class RectifiedPatternSimilarityLoss(TimedModule): if std is not None: mask = mask * std - diff = torchext.photometric_loss(pattern_proj.contiguous(), im.contiguous(), 9, self.loss_type, self.loss_eps) + diff = torchext.photometric_loss_pytorch(pattern_proj.contiguous(), im.contiguous(), 9, self.loss_type, self.loss_eps) val = (mask * diff).sum() / mask.sum() return val, pattern_proj