Add real world dataset
This commit is contained in:
parent
bcb1e2c6ff
commit
cebf776714
123
data/dataset.py
123
data/dataset.py
@ -144,5 +144,128 @@ class TrackSynDataset(torchext.BaseDataset):
|
||||
return K
|
||||
|
||||
|
||||
class RealWorldDataset(torchext.BaseDataset):
|
||||
'''
|
||||
Load locally saved real-world dataset
|
||||
Please generate the dataset beforehand
|
||||
'''
|
||||
|
||||
def __init__(self, settings_path, sample_paths, track_length=1, train=True, data_aug=False):
|
||||
super().__init__(train=train)
|
||||
|
||||
self.settings_path = settings_path
|
||||
self.sample_paths = sample_paths
|
||||
self.data_aug = data_aug
|
||||
self.train = train
|
||||
self.track_length = track_length
|
||||
assert (track_length <= 4)
|
||||
|
||||
with open(str(settings_path), 'rb') as f:
|
||||
settings = pickle.load(f)
|
||||
self.imsizes = settings['imsizes']
|
||||
self.patterns = settings['patterns']
|
||||
self.focal_lengths = settings['focal_lengths']
|
||||
self.baseline = settings['baseline']
|
||||
self.K = settings['K']
|
||||
|
||||
self.scale = 1
|
||||
|
||||
self.max_shift = 0
|
||||
self.max_blur = 0.5
|
||||
self.max_noise = 3.0
|
||||
self.max_sp_noise = 0.0005
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sample_paths)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if not self.train:
|
||||
rng = self.get_rng(idx)
|
||||
else:
|
||||
rng = np.random.RandomState()
|
||||
sample_path = self.sample_paths[idx]
|
||||
|
||||
if self.train:
|
||||
track_ind = np.random.permutation(4)[0:self.track_length]
|
||||
else:
|
||||
track_ind = [0]
|
||||
|
||||
ret = {}
|
||||
ret['id'] = idx
|
||||
|
||||
# load imgs, at all scales
|
||||
for sidx in range(len(self.imsizes)):
|
||||
imgs = []
|
||||
ambs = []
|
||||
grads = []
|
||||
for tidx in track_ind:
|
||||
imgs.append(np.load(os.path.join(sample_path, f'im0.npy'), allow_pickle=True))
|
||||
ambs.append(np.load(os.path.join(sample_path, f'ambient0.npy'), allow_pickle=True))
|
||||
grads.append(np.load(os.path.join(sample_path, f'grad0.npy'), allow_pickle=True))
|
||||
ret[f'im0'] = np.stack(imgs, axis=0)
|
||||
ret[f'ambient0'] = np.stack(ambs, axis=0)
|
||||
ret[f'grad0'] = np.stack(grads, axis=0)
|
||||
|
||||
# load disp and grad only at full resolution
|
||||
# FIXME do this for our stuff
|
||||
disps = []
|
||||
R = []
|
||||
t = []
|
||||
for tidx in track_ind:
|
||||
disps.append(np.load(os.path.join(sample_path, f'disp0_{tidx}.npy'), allow_pickle=True))
|
||||
R.append(np.load(os.path.join(sample_path, f'R_{tidx}.npy'), allow_pickle=True))
|
||||
t.append(np.load(os.path.join(sample_path, f't_{tidx}.npy'), allow_pickle=True))
|
||||
ret[f'disp0'] = np.stack(disps, axis=0)
|
||||
ret['R'] = np.stack(R, axis=0)
|
||||
ret['t'] = np.stack(t, axis=0)
|
||||
|
||||
blend_im = np.load(os.path.join(sample_path, 'blend_im.npy'), allow_pickle=True)
|
||||
ret['blend_im'] = blend_im.astype(np.float32)
|
||||
|
||||
#### apply data augmentation at different scales seperately, only work for max_shift=0
|
||||
if self.data_aug:
|
||||
for sidx in range(len(self.imsizes)):
|
||||
if sidx == 0:
|
||||
img = ret[f'im{sidx}']
|
||||
disp = ret[f'disp{sidx}']
|
||||
grad = ret[f'grad{sidx}']
|
||||
img_aug = np.zeros_like(img)
|
||||
disp_aug = np.zeros_like(img)
|
||||
grad_aug = np.zeros_like(img)
|
||||
for i in range(img.shape[0]):
|
||||
img_aug_, disp_aug_, grad_aug_ = augment_image(img[i, 0], rng,
|
||||
disp=disp[i, 0], grad=grad[i, 0],
|
||||
max_shift=self.max_shift, max_blur=self.max_blur,
|
||||
max_noise=self.max_noise,
|
||||
max_sp_noise=self.max_sp_noise)
|
||||
img_aug[i] = img_aug_[None].astype(np.float32)
|
||||
disp_aug[i] = disp_aug_[None].astype(np.float32)
|
||||
grad_aug[i] = grad_aug_[None].astype(np.float32)
|
||||
ret[f'im{sidx}'] = img_aug
|
||||
ret[f'disp{sidx}'] = disp_aug
|
||||
ret[f'grad{sidx}'] = grad_aug
|
||||
else:
|
||||
img = ret[f'im{sidx}']
|
||||
img_aug = np.zeros_like(img)
|
||||
for i in range(img.shape[0]):
|
||||
img_aug_, _, _ = augment_image(img[i, 0], rng,
|
||||
max_shift=self.max_shift, max_blur=self.max_blur,
|
||||
max_noise=self.max_noise, max_sp_noise=self.max_sp_noise)
|
||||
img_aug[i] = img_aug_[None].astype(np.float32)
|
||||
ret[f'im{sidx}'] = img_aug
|
||||
|
||||
if len(track_ind) == 1:
|
||||
for key, val in ret.items():
|
||||
if key != 'blend_im' and key != 'id':
|
||||
ret[key] = val[0]
|
||||
|
||||
return ret
|
||||
|
||||
def getK(self, sidx=0):
|
||||
K = self.K.copy() / (2 ** sidx)
|
||||
K[2, 2] = 1
|
||||
return K
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
|
Loading…
Reference in New Issue
Block a user