From e3303cf9d4eb40168e4aebd38d7fe1e0b6ad2532 Mon Sep 17 00:00:00 2001 From: "Cpt.Captain" Date: Tue, 22 Feb 2022 13:32:28 +0100 Subject: [PATCH] Add wandb, make batch and image size configurable, fix some bugs, improve supervised loss function, make use of some RL stuff just implemented --- model/exp_synph_real.py | 88 ++++++++++++++++++++++++++++------------- 1 file changed, 61 insertions(+), 27 deletions(-) diff --git a/model/exp_synph_real.py b/model/exp_synph_real.py index 5f25cc1..4f2d01f 100644 --- a/model/exp_synph_real.py +++ b/model/exp_synph_real.py @@ -10,6 +10,7 @@ import matplotlib.pyplot as plt import cv2 import torchvision.transforms as transforms +import wandb import co import torchext @@ -18,32 +19,39 @@ from data import dataset class Worker(torchext.Worker): - def __init__(self, args, num_workers=18, train_batch_size=4, test_batch_size=4, save_frequency=1, **kwargs): + def __init__(self, args, num_workers=18, train_batch_size=6, test_batch_size=6, save_frequency=1, **kwargs): + if 'batch_size' in dir(args): + train_batch_size = args.batch_size + test_batch_size = args.batch_size super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, train_batch_size=train_batch_size, test_batch_size=test_batch_size, save_frequency=save_frequency, **kwargs) - self.ms = args.ms self.pattern_path = args.pattern_path self.lcn_radius = args.lcn_radius self.dp_weight = args.dp_weight self.data_type = args.data_type - self.imsizes = [(488, 648)] - for iter in range(3): - self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2))) - with open('config.json') as fp: config = json.load(fp) data_root = Path(config['DATA_ROOT']) + self.imsizes = [tuple(map(int, config['IMSIZE'].split(',')))] + + for iter in range(3): + self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2))) + self.settings_path = data_root / self.data_type / 'settings.pkl' sample_paths = sorted((data_root / 
self.data_type).glob('0*/')) - self.train_paths = sample_paths[2 ** 10:] - self.test_paths = sample_paths[:2 ** 8] + # calc split: the first 20% of the sorted samples are held out for testing, the remaining 80% are used for training + # since we don't have a lot of RL footage, we compute the sizes from the number of available samples + train_size = int(len(sample_paths) * 0.8) + test_size = len(sample_paths) - train_size + self.train_paths = sample_paths[test_size:] + self.test_paths = sample_paths[:test_size] - # supervise the edge encoder with only 2**8 samples - self.train_edge = len(self.train_paths) - 2 ** 8 + # don't just supervise the edge encoder with only 2**8 samples + self.train_edge = len(self.train_paths) self.lcn_in = networks.LCN(self.lcn_radius, 0.05) self.disparity_loss = networks.DisparityLoss() @@ -51,6 +59,25 @@ class Worker(torchext.Worker): self.sup_disp_loss = torch.nn.MSELoss() self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device)) + # FIXME L2 Regularization, try it!! + # l2_lambda = 0.001 + # l2_norm = sum(p.pow(2.0).sum() + # for p in net.parameters()) + # self.sup_disp_loss = torch.nn.MSELoss() + l2_lambda * l2_norm + # FIXME try using log of this loss, otherwise it's very large compared to others + # self.sup_disp_loss = torch.nn.MSELoss() + class RMSLELoss(torch.nn.Module): + def __init__(self): + super().__init__() + self.mse = torch.nn.MSELoss() + + def forward(self, pred, actual): + # FIXME rename this if log is better than sqrt + return torch.log(self.mse(pred, actual)) + # return torch.sqrt(self.mse(torch.log(pred + 1), torch.log(actual + 1))) + + self.sup_disp_loss = RMSLELoss() + # evaluate in the region where opencv Block Matching has valid values self.eval_mask = np.zeros(self.imsizes[0]) self.eval_mask[13:self.imsizes[0][0] - 13, 140:self.imsizes[0][1] - 13] = 1 @@ -59,14 +86,13 @@ class Worker(torchext.Worker): self.eval_w = self.imsizes[0][1] - 13 - 140 def get_train_set(self): - train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, + train_set = 
dataset.RealWorldDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=1) - return train_set def get_test_sets(self): test_sets = torchext.TestSets() - test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, + test_set = dataset.RealWorldDataset(self.settings_path, self.test_paths, train=False, data_aug=True, track_length=1) test_sets.append('simple', test_set, test_frequency=1) @@ -102,12 +128,12 @@ class Worker(torchext.Worker): self.data[key_std] = im_std.to(device).detach() def net_forward(self, net, train): - # FIXME hier schnibbeln? out = net(self.data['im0']) return out def loss_forward(self, out, train): out, edge = out + losses = {} if not (isinstance(out, tuple) or isinstance(out, list)): out = [out] if not (isinstance(edge, tuple) or isinstance(edge, list)): @@ -118,23 +144,30 @@ class Worker(torchext.Worker): # apply photometric loss for s, l, o in zip(itertools.count(), self.losses, out): val, pattern_proj = l(o[0], self.data[f'im{s}'][:, 0:1, ...], self.data[f'std{s}']) + if s == 0: self.pattern_proj = pattern_proj.detach() vals.append(val) - - # apply disparity loss - # 1-edge as ground truth edge if inversed + losses['photometric'] = val + # 1-edge as ground truth edge if inverted if isinstance(edge, tuple): edge0 = 1 - torch.sigmoid(edge[0][0]) else: edge0 = 1 - torch.sigmoid(edge[0]) val = 0 if isinstance(out[0], tuple): - # NOTE use supervised disparity loss - val += self.sup_disp_loss(out[0][1], self.data['disp0']) - val += self.disparity_loss(out[0][0], edge0) + sup_loss = self.sup_disp_loss(out[0][1], self.data['disp0']) + val += sup_loss + + disp_loss = self.disparity_loss(out[0][0], edge0) + val += disp_loss + + losses['GT Supervised disparity loss'] = sup_loss * self.dp_weight + losses['OG disparity loss'] = disp_loss * self.dp_weight else: - val += self.disparity_loss(out[0], edge0) + disp_loss = self.disparity_loss(out[0], edge0) + val += disp_loss + losses['OG 
disparity loss'] = disp_loss * self.dp_weight if self.dp_weight > 0: vals.append(val * self.dp_weight) @@ -159,15 +192,20 @@ class Worker(torchext.Worker): self.edge = e.detach() self.edge = torch.sigmoid(self.edge) self.edge_gt = grad.detach() + losses['edge'] = val vals.append(val) + wandb.log(losses) return vals def numpy_in_out(self, output): output, edge = output if not (isinstance(output, tuple) or isinstance(output, list)): output = [output] - es = output[0][0].detach().to('cpu').numpy() + if isinstance(output[0], tuple): + es = output[0][0].detach().to('cpu').numpy() + else: + es = output[0].detach().to('cpu').numpy() gt = self.data['disp0'].to('cpu').numpy().astype(np.float32) im = self.data['im0'][:, 0:1, ...].detach().to('cpu').numpy() @@ -250,6 +288,7 @@ class Worker(torchext.Worker): plt.tight_layout() plt.savefig(str(out_path)) + wandb.log({f'results_{"_".join(out_path.stem.split("_")[:-1])}': plt}) plt.close(fig) def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]): @@ -293,9 +332,4 @@ class Worker(torchext.Worker): if __name__ == '__main__': - # FIXME Nicolas fixe idee - # SGBM nutzen, um GT zu finden - # bei dispnet (oder w/e) letzte paar layer 'dublizieren' (zweiten head bauen) und so mehrere Loss funktionen gleichzeitig trainieren - # L1 + L2 und dann im selben Backwardspass optimieren - # für das ganze forward pass anpassen pass