Add wandb
This commit is contained in:
parent
3da66347e7
commit
9ed0c264f5
@ -14,6 +14,7 @@ import json
|
||||
import matplotlib.pyplot as plt
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
import wandb
|
||||
|
||||
|
||||
class StopWatch(object):
|
||||
@ -81,7 +82,7 @@ class ETA(object):
|
||||
|
||||
class Worker(object):
|
||||
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16,
|
||||
num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1):
|
||||
num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1, no_double_heads=True):
|
||||
self.out_root = Path(out_root)
|
||||
self.experiment_name = experiment_name
|
||||
self.epochs = epochs
|
||||
@ -93,6 +94,7 @@ class Worker(object):
|
||||
self.train_device = train_device
|
||||
self.test_device = test_device
|
||||
self.max_train_iter = max_train_iter
|
||||
self.double_heads = no_double_heads
|
||||
|
||||
self.errs_list = []
|
||||
|
||||
@ -372,6 +374,7 @@ class Worker(object):
|
||||
num_workers=self.num_workers, drop_last=True, pin_memory=False)
|
||||
|
||||
net = net.to(self.train_device)
|
||||
wandb.watch(net)
|
||||
net.train()
|
||||
|
||||
mean_loss = None
|
||||
@ -396,6 +399,7 @@ class Worker(object):
|
||||
stopwatch.start('loss')
|
||||
errs = self.loss_forward(output, train=True)
|
||||
if isinstance(errs, dict):
|
||||
wandb.log(errs)
|
||||
masks = errs['masks']
|
||||
errs = errs['errs']
|
||||
else:
|
||||
@ -442,6 +446,7 @@ class Worker(object):
|
||||
|
||||
err_str = self.format_err_str(mean_loss)
|
||||
logging.info(f'avg train_loss={err_str}')
|
||||
wandb.log({'mean_loss': mean_loss})
|
||||
return mean_loss
|
||||
|
||||
def callback_test_start(self, epoch, set_idx):
|
||||
@ -495,6 +500,7 @@ class Worker(object):
|
||||
stopwatch.start('loss')
|
||||
errs = self.loss_forward(output, train=False)
|
||||
if isinstance(errs, dict):
|
||||
wandb.log(errs)
|
||||
masks = errs['masks']
|
||||
errs = errs['errs']
|
||||
else:
|
||||
@ -529,4 +535,5 @@ class Worker(object):
|
||||
|
||||
err_str = self.format_err_str(mean_loss)
|
||||
logging.info(f'test epoch {epoch}: avg test_loss={err_str}')
|
||||
wandb.log({'Test loss': mean_loss})
|
||||
return mean_loss
|
||||
|
Loading…
Reference in New Issue
Block a user