Add wandb, make batch and image size configurable, fix some bugs
This commit is contained in:
parent
e3303cf9d4
commit
7633990c81
@ -12,12 +12,18 @@ import torchext
|
||||
from model import networks
|
||||
from data import dataset
|
||||
|
||||
import wandb
|
||||
|
||||
|
||||
class Worker(torchext.Worker):
|
||||
def __init__(self, args, num_workers=18, train_batch_size=6, test_batch_size=6, save_frequency=1, **kwargs):
|
||||
def __init__(self, args, num_workers=18, train_batch_size=2, test_batch_size=2, save_frequency=1, **kwargs):
|
||||
if 'batch_size' in dir(args):
|
||||
train_batch_size = args.batch_size
|
||||
test_batch_size = args.batch_size
|
||||
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
|
||||
train_batch_size=train_batch_size, test_batch_size=test_batch_size,
|
||||
save_frequency=save_frequency, **kwargs)
|
||||
print(args.no_double_heads)
|
||||
|
||||
self.ms = args.ms
|
||||
self.pattern_path = args.pattern_path
|
||||
@ -28,13 +34,15 @@ class Worker(torchext.Worker):
|
||||
self.data_type = args.data_type
|
||||
assert (self.track_length > 1)
|
||||
|
||||
self.imsizes = [(488, 648)]
|
||||
for iter in range(3):
|
||||
self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
|
||||
|
||||
with open('config.json') as fp:
|
||||
config = json.load(fp)
|
||||
data_root = Path(config['DATA_ROOT'])
|
||||
self.imsizes = [tuple(map(int, config['IMSIZE'].split(',')))]
|
||||
|
||||
for iter in range(3):
|
||||
self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
|
||||
|
||||
|
||||
self.settings_path = data_root / self.data_type / 'settings.pkl'
|
||||
sample_paths = sorted((data_root / self.data_type).glob('0*/'))
|
||||
|
||||
@ -84,7 +92,9 @@ class Worker(torchext.Worker):
|
||||
Ki = np.linalg.inv(K)
|
||||
K = torch.from_numpy(K)
|
||||
Ki = torch.from_numpy(Ki)
|
||||
ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1)
|
||||
# FIXME why would i need to increase this?
|
||||
ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.5)
|
||||
# ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1)
|
||||
|
||||
self.ph_losses.append(ph_loss)
|
||||
self.ge_losses.append(ge_loss)
|
||||
@ -130,7 +140,8 @@ class Worker(torchext.Worker):
|
||||
out = out.view(tl, bs, *out.shape[1:])
|
||||
edge = edge.view(tl, bs, *out.shape[1:])
|
||||
else:
|
||||
out = [o.view(tl, bs, *o.shape[1:]) for o in out]
|
||||
out = [o[0].view(tl, bs, *o[0].shape[1:]) for o in out]
|
||||
# out = [o.view(tl, bs, *o.shape[1:]) for o in out]
|
||||
edge = [e.view(tl, bs, *e.shape[1:]) for e in edge]
|
||||
return out, edge
|
||||
|
||||
@ -140,6 +151,7 @@ class Worker(torchext.Worker):
|
||||
out = [out]
|
||||
vals = []
|
||||
diffs = []
|
||||
losses = {}
|
||||
|
||||
# apply photometric loss
|
||||
for s, l, o in zip(itertools.count(), self.ph_losses, out):
|
||||
@ -149,6 +161,7 @@ class Worker(torchext.Worker):
|
||||
std = self.data[f'std{s}']
|
||||
std = std.view(-1, *std.shape[2:])
|
||||
val, pattern_proj = l(o, im[:, 0:1, ...], std)
|
||||
losses['photometric'] = val
|
||||
vals.append(val)
|
||||
if s == 0:
|
||||
self.pattern_proj = pattern_proj.detach()
|
||||
@ -159,6 +172,7 @@ class Worker(torchext.Worker):
|
||||
edge0 = edge0.view(-1, *edge0.shape[2:])
|
||||
out0 = out[0].view(-1, *out[0].shape[2:])
|
||||
val = self.disparity_loss(out0, edge0)
|
||||
losses['disparity'] = val * self.dp_weight
|
||||
if self.dp_weight > 0:
|
||||
vals.append(val * self.dp_weight)
|
||||
|
||||
@ -177,6 +191,7 @@ class Worker(torchext.Worker):
|
||||
val = self.edge_loss(e, grad)
|
||||
else:
|
||||
val = torch.zeros_like(vals[0])
|
||||
losses['edge loss'] = val
|
||||
vals.append(val)
|
||||
|
||||
if train is False:
|
||||
@ -201,8 +216,10 @@ class Worker(torchext.Worker):
|
||||
t1 = t[tidx1]
|
||||
|
||||
val = ge_loss(depth0, depth1, R0, t0, R1, t1)
|
||||
losses['geometric loss'] = val
|
||||
vals.append(val * self.ge_weight / ge_num)
|
||||
|
||||
wandb.log(losses)
|
||||
return vals
|
||||
|
||||
def numpy_in_out(self, output):
|
||||
@ -287,6 +304,7 @@ class Worker(torchext.Worker):
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(str(out_path))
|
||||
wandb.log({f'results_{"_".join(out_path.stem.split("_")[:-1])}': plt})
|
||||
plt.close(fig)
|
||||
|
||||
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
|
||||
|
Loading…
Reference in New Issue
Block a user