You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

30 lines
956 B

5 years ago
import torch
import math
import numpy as np
from .functions import *
5 years ago
class CoordConv2d(torch.nn.Module):
def __init__(self, channels_in, channels_out, kernel_size, stride, padding):
super().__init__()
5 years ago
self.conv = torch.nn.Conv2d(channels_in + 2, channels_out, kernel_size=kernel_size, padding=padding,
stride=stride)
5 years ago
self.uv = None
5 years ago
def forward(self, x):
if self.uv is None:
height, width = x.shape[2], x.shape[3]
u, v = np.meshgrid(range(width), range(height))
u = 2 * u / (width - 1) - 1
v = 2 * v / (height - 1) - 1
uv = np.stack((u, v)).reshape(1, 2, height, width)
self.uv = torch.from_numpy(uv.astype(np.float32))
self.uv = self.uv.to(x.device)
uv = self.uv.expand(x.shape[0], *self.uv.shape[1:])
xuv = torch.cat((x, uv), dim=1)
y = self.conv(xuv)
return y