Add wandb, make batch and image size configurable, fix some bugs, improve supervised loss function, make use of some RL stuff just implemented

master
Cpt.Captain 3 years ago
parent b7dbc59c25
commit e3303cf9d4
  1. 86
      model/exp_synph_real.py

@ -10,6 +10,7 @@ import matplotlib.pyplot as plt
import cv2 import cv2
import torchvision.transforms as transforms import torchvision.transforms as transforms
import wandb
import co import co
import torchext import torchext
@ -18,32 +19,39 @@ from data import dataset
class Worker(torchext.Worker): class Worker(torchext.Worker):
def __init__(self, args, num_workers=18, train_batch_size=4, test_batch_size=4, save_frequency=1, **kwargs): def __init__(self, args, num_workers=18, train_batch_size=6, test_batch_size=6, save_frequency=1, **kwargs):
if 'batch_size' in dir(args):
train_batch_size = args.batch_size
test_batch_size = args.batch_size
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
train_batch_size=train_batch_size, test_batch_size=test_batch_size, train_batch_size=train_batch_size, test_batch_size=test_batch_size,
save_frequency=save_frequency, **kwargs) save_frequency=save_frequency, **kwargs)
self.ms = args.ms self.ms = args.ms
self.pattern_path = args.pattern_path self.pattern_path = args.pattern_path
self.lcn_radius = args.lcn_radius self.lcn_radius = args.lcn_radius
self.dp_weight = args.dp_weight self.dp_weight = args.dp_weight
self.data_type = args.data_type self.data_type = args.data_type
self.imsizes = [(488, 648)]
for iter in range(3):
self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
with open('config.json') as fp: with open('config.json') as fp:
config = json.load(fp) config = json.load(fp)
data_root = Path(config['DATA_ROOT']) data_root = Path(config['DATA_ROOT'])
self.imsizes = [tuple(map(int, config['IMSIZE'].split(',')))]
for iter in range(3):
self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
self.settings_path = data_root / self.data_type / 'settings.pkl' self.settings_path = data_root / self.data_type / 'settings.pkl'
sample_paths = sorted((data_root / self.data_type).glob('0*/')) sample_paths = sorted((data_root / self.data_type).glob('0*/'))
self.train_paths = sample_paths[2 ** 10:] # calc split
self.test_paths = sample_paths[:2 ** 8] # since we don't have a lot or RL footage, we compute it as we go
train_size = len(sample_paths) * 0.8 // 1
test_size = 1 - train_size
self.train_paths = sample_paths[test_size:]
self.test_paths = sample_paths[:train_size]
# supervise the edge encoder with only 2**8 samples # don't just supervise the edge encoder with only 2**8 samples
self.train_edge = len(self.train_paths) - 2 ** 8 self.train_edge = len(self.train_paths)
self.lcn_in = networks.LCN(self.lcn_radius, 0.05) self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
self.disparity_loss = networks.DisparityLoss() self.disparity_loss = networks.DisparityLoss()
@ -51,6 +59,25 @@ class Worker(torchext.Worker):
self.sup_disp_loss = torch.nn.MSELoss() self.sup_disp_loss = torch.nn.MSELoss()
self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device)) self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))
# FIXME L2 Regularization, try it!!
# l2_lambda = 0.001
# l2_norm = sum(p.pow(2.0).sum()
# for p in net.parameters())
# self.sup_disp_loss = torch.nn.MSELoss() + l2_lambda * l2_norm
# FIXME try using log of this loss, otherwise it's very large compared to others
# self.sup_disp_loss = torch.nn.MSELoss()
class RMSLELoss(torch.nn.Module):
def __init__(self):
super().__init__()
self.mse = torch.nn.MSELoss()
def forward(self, pred, actual):
# FIXME rename this if log is better than sqrt
return torch.log(self.mse(pred, actual))
# return torch.sqrt(self.mse(torch.log(pred + 1), torch.log(actual + 1)))
self.sup_disp_loss = RMSLELoss()
# evaluate in the region where opencv Block Matching has valid values # evaluate in the region where opencv Block Matching has valid values
self.eval_mask = np.zeros(self.imsizes[0]) self.eval_mask = np.zeros(self.imsizes[0])
self.eval_mask[13:self.imsizes[0][0] - 13, 140:self.imsizes[0][1] - 13] = 1 self.eval_mask[13:self.imsizes[0][0] - 13, 140:self.imsizes[0][1] - 13] = 1
@ -59,14 +86,13 @@ class Worker(torchext.Worker):
self.eval_w = self.imsizes[0][1] - 13 - 140 self.eval_w = self.imsizes[0][1] - 13 - 140
def get_train_set(self): def get_train_set(self):
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, train_set = dataset.RealWorldDataset(self.settings_path, self.train_paths, train=True, data_aug=True,
track_length=1) track_length=1)
return train_set return train_set
def get_test_sets(self): def get_test_sets(self):
test_sets = torchext.TestSets() test_sets = torchext.TestSets()
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, test_set = dataset.RealWorldDataset(self.settings_path, self.test_paths, train=False, data_aug=True,
track_length=1) track_length=1)
test_sets.append('simple', test_set, test_frequency=1) test_sets.append('simple', test_set, test_frequency=1)
@ -102,12 +128,12 @@ class Worker(torchext.Worker):
self.data[key_std] = im_std.to(device).detach() self.data[key_std] = im_std.to(device).detach()
def net_forward(self, net, train): def net_forward(self, net, train):
# FIXME hier schnibbeln?
out = net(self.data['im0']) out = net(self.data['im0'])
return out return out
def loss_forward(self, out, train): def loss_forward(self, out, train):
out, edge = out out, edge = out
losses = {}
if not (isinstance(out, tuple) or isinstance(out, list)): if not (isinstance(out, tuple) or isinstance(out, list)):
out = [out] out = [out]
if not (isinstance(edge, tuple) or isinstance(edge, list)): if not (isinstance(edge, tuple) or isinstance(edge, list)):
@ -118,23 +144,30 @@ class Worker(torchext.Worker):
# apply photometric loss # apply photometric loss
for s, l, o in zip(itertools.count(), self.losses, out): for s, l, o in zip(itertools.count(), self.losses, out):
val, pattern_proj = l(o[0], self.data[f'im{s}'][:, 0:1, ...], self.data[f'std{s}']) val, pattern_proj = l(o[0], self.data[f'im{s}'][:, 0:1, ...], self.data[f'std{s}'])
if s == 0: if s == 0:
self.pattern_proj = pattern_proj.detach() self.pattern_proj = pattern_proj.detach()
vals.append(val) vals.append(val)
losses['photometric'] = val
# apply disparity loss # 1-edge as ground truth edge if inverted
# 1-edge as ground truth edge if inversed
if isinstance(edge, tuple): if isinstance(edge, tuple):
edge0 = 1 - torch.sigmoid(edge[0][0]) edge0 = 1 - torch.sigmoid(edge[0][0])
else: else:
edge0 = 1 - torch.sigmoid(edge[0]) edge0 = 1 - torch.sigmoid(edge[0])
val = 0 val = 0
if isinstance(out[0], tuple): if isinstance(out[0], tuple):
# NOTE use supervised disparity loss sup_loss = self.sup_disp_loss(out[0][1], self.data['disp0'])
val += self.sup_disp_loss(out[0][1], self.data['disp0']) val += sup_loss
val += self.disparity_loss(out[0][0], edge0)
disp_loss = self.disparity_loss(out[0][0], edge0)
val += disp_loss
losses['GT Supervised disparity loss'] = sup_loss * self.dp_weight
losses['OG disparity loss'] = disp_loss * self.dp_weight
else: else:
val += self.disparity_loss(out[0], edge0) disp_loss = self.disparity_loss(out[0], edge0)
val += disp_loss
losses['OG disparity loss'] = disp_loss * self.dp_weight
if self.dp_weight > 0: if self.dp_weight > 0:
vals.append(val * self.dp_weight) vals.append(val * self.dp_weight)
@ -159,15 +192,20 @@ class Worker(torchext.Worker):
self.edge = e.detach() self.edge = e.detach()
self.edge = torch.sigmoid(self.edge) self.edge = torch.sigmoid(self.edge)
self.edge_gt = grad.detach() self.edge_gt = grad.detach()
losses['edge'] = val
vals.append(val) vals.append(val)
wandb.log(losses)
return vals return vals
def numpy_in_out(self, output): def numpy_in_out(self, output):
output, edge = output output, edge = output
if not (isinstance(output, tuple) or isinstance(output, list)): if not (isinstance(output, tuple) or isinstance(output, list)):
output = [output] output = [output]
if isinstance(output[0], tuple):
es = output[0][0].detach().to('cpu').numpy() es = output[0][0].detach().to('cpu').numpy()
else:
es = output[0].detach().to('cpu').numpy()
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32) gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
im = self.data['im0'][:, 0:1, ...].detach().to('cpu').numpy() im = self.data['im0'][:, 0:1, ...].detach().to('cpu').numpy()
@ -250,6 +288,7 @@ class Worker(torchext.Worker):
plt.tight_layout() plt.tight_layout()
plt.savefig(str(out_path)) plt.savefig(str(out_path))
wandb.log({f'results_{"_".join(out_path.stem.split("_")[:-1])}': plt})
plt.close(fig) plt.close(fig)
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]): def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]):
@ -293,9 +332,4 @@ class Worker(torchext.Worker):
if __name__ == '__main__': if __name__ == '__main__':
# FIXME Nicolas fixe idee
# SGBM nutzen, um GT zu finden
# bei dispnet (oder w/e) letzte paar layer 'dublizieren' (zweiten head bauen) und so mehrere Loss funktionen gleichzeitig trainieren
# L1 + L2 und dann im selben Backwardspass optimieren
# für das ganze forward pass anpassen
pass pass

Loading…
Cancel
Save