Add wandb, more args, increase learning rate

This commit is contained in:
Cpt.Captain 2022-02-22 13:35:50 +01:00
parent 168516924e
commit 3da66347e7

View File

@ -2,27 +2,41 @@ import os
import torch
from model import exp_synph
from model import exp_synphge
from model import exp_synph_real
from model import networks
from co.args import parse_args
import wandb
wandb.init(project="connecting_the_dots", entity="cpt-captain")
wandb.config.epochs = 100
wandb.config.batch_size = 3
# parse args
args = parse_args()
double_head = args.no_double_heads
wandb.config.update(args, allow_val_change=True)
# loss types
if args.loss == 'ph':
worker = exp_synph.Worker(args)
elif args.loss == 'phge':
worker = exp_synphge.Worker(args)
elif args.loss == 'phirl':
worker = exp_synph_real.Worker(args)
# double_head = False
# concatenation of original image and lcn image
channels_in = 2
# set up network
net = networks.DispEdgeDecoders(channels_in=channels_in, max_disp=args.max_disp, imsizes=worker.imsizes,
output_ms=worker.ms)
output_ms=worker.ms, double_head=double_head)
# optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
# optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)
# start the work
worker.do(net, optimizer)