master
Cpt.Captain 3 years ago
parent 3da66347e7
commit 9ed0c264f5
  1. 9
      torchext/worker.py

@ -14,6 +14,7 @@ import json
import matplotlib.pyplot as plt
import time
from collections import OrderedDict
import wandb
class StopWatch(object):
@ -81,7 +82,7 @@ class ETA(object):
class Worker(object):
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16,
num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1):
num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1, no_double_heads=True):
self.out_root = Path(out_root)
self.experiment_name = experiment_name
self.epochs = epochs
@ -93,6 +94,7 @@ class Worker(object):
self.train_device = train_device
self.test_device = test_device
self.max_train_iter = max_train_iter
self.double_heads = no_double_heads
self.errs_list = []
@ -372,6 +374,7 @@ class Worker(object):
num_workers=self.num_workers, drop_last=True, pin_memory=False)
net = net.to(self.train_device)
wandb.watch(net)
net.train()
mean_loss = None
@ -396,6 +399,7 @@ class Worker(object):
stopwatch.start('loss')
errs = self.loss_forward(output, train=True)
if isinstance(errs, dict):
wandb.log(errs)
masks = errs['masks']
errs = errs['errs']
else:
@ -442,6 +446,7 @@ class Worker(object):
err_str = self.format_err_str(mean_loss)
logging.info(f'avg train_loss={err_str}')
wandb.log({'mean_loss': mean_loss})
return mean_loss
def callback_test_start(self, epoch, set_idx):
@ -495,6 +500,7 @@ class Worker(object):
stopwatch.start('loss')
errs = self.loss_forward(output, train=False)
if isinstance(errs, dict):
wandb.log(errs)
masks = errs['masks']
errs = errs['errs']
else:
@ -529,4 +535,5 @@ class Worker(object):
err_str = self.format_err_str(mean_loss)
logging.info(f'test epoch {epoch}: avg test_loss={err_str}')
wandb.log({'Test loss': mean_loss})
return mean_loss

Loading…
Cancel
Save