Add wandb, make batch and image size configurable, fix some bugs

This commit is contained in:
Cpt.Captain 2022-02-22 13:32:50 +01:00
parent e3303cf9d4
commit 7633990c81

View File

@ -12,12 +12,18 @@ import torchext
from model import networks
from data import dataset
import wandb
class Worker(torchext.Worker):
def __init__(self, args, num_workers=18, train_batch_size=6, test_batch_size=6, save_frequency=1, **kwargs):
def __init__(self, args, num_workers=18, train_batch_size=2, test_batch_size=2, save_frequency=1, **kwargs):
if 'batch_size' in dir(args):
train_batch_size = args.batch_size
test_batch_size = args.batch_size
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
train_batch_size=train_batch_size, test_batch_size=test_batch_size,
save_frequency=save_frequency, **kwargs)
print(args.no_double_heads)
self.ms = args.ms
self.pattern_path = args.pattern_path
@ -28,13 +34,15 @@ class Worker(torchext.Worker):
self.data_type = args.data_type
assert (self.track_length > 1)
self.imsizes = [(488, 648)]
for iter in range(3):
self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
with open('config.json') as fp:
config = json.load(fp)
data_root = Path(config['DATA_ROOT'])
self.imsizes = [tuple(map(int, config['IMSIZE'].split(',')))]
for iter in range(3):
self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
self.settings_path = data_root / self.data_type / 'settings.pkl'
sample_paths = sorted((data_root / self.data_type).glob('0*/'))
@ -84,7 +92,9 @@ class Worker(torchext.Worker):
Ki = np.linalg.inv(K)
K = torch.from_numpy(K)
Ki = torch.from_numpy(Ki)
ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1)
# FIXME why would i need to increase this?
ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.5)
# ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1)
self.ph_losses.append(ph_loss)
self.ge_losses.append(ge_loss)
@ -130,7 +140,8 @@ class Worker(torchext.Worker):
out = out.view(tl, bs, *out.shape[1:])
edge = edge.view(tl, bs, *out.shape[1:])
else:
out = [o.view(tl, bs, *o.shape[1:]) for o in out]
out = [o[0].view(tl, bs, *o[0].shape[1:]) for o in out]
# out = [o.view(tl, bs, *o.shape[1:]) for o in out]
edge = [e.view(tl, bs, *e.shape[1:]) for e in edge]
return out, edge
@ -140,6 +151,7 @@ class Worker(torchext.Worker):
out = [out]
vals = []
diffs = []
losses = {}
# apply photometric loss
for s, l, o in zip(itertools.count(), self.ph_losses, out):
@ -149,6 +161,7 @@ class Worker(torchext.Worker):
std = self.data[f'std{s}']
std = std.view(-1, *std.shape[2:])
val, pattern_proj = l(o, im[:, 0:1, ...], std)
losses['photometric'] = val
vals.append(val)
if s == 0:
self.pattern_proj = pattern_proj.detach()
@ -159,6 +172,7 @@ class Worker(torchext.Worker):
edge0 = edge0.view(-1, *edge0.shape[2:])
out0 = out[0].view(-1, *out[0].shape[2:])
val = self.disparity_loss(out0, edge0)
losses['disparity'] = val * self.dp_weight
if self.dp_weight > 0:
vals.append(val * self.dp_weight)
@ -177,6 +191,7 @@ class Worker(torchext.Worker):
val = self.edge_loss(e, grad)
else:
val = torch.zeros_like(vals[0])
losses['edge loss'] = val
vals.append(val)
if train is False:
@ -201,8 +216,10 @@ class Worker(torchext.Worker):
t1 = t[tidx1]
val = ge_loss(depth0, depth1, R0, t0, R1, t1)
losses['geometric loss'] = val
vals.append(val * self.ge_weight / ge_num)
wandb.log(losses)
return vals
def numpy_in_out(self, output):
@ -287,6 +304,7 @@ class Worker(torchext.Worker):
plt.tight_layout()
plt.savefig(str(out_path))
wandb.log({f'results_{"_".join(out_path.stem.split("_")[:-1])}': plt})
plt.close(fig)
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):