api_server.py: add missing device
This commit is contained in:
parent
8f0c3a32b8
commit
ed71c16912
@ -37,7 +37,7 @@ model_path = "../train_log/models/latest.pth"
|
||||
device = torch.device('cuda:0')
|
||||
|
||||
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
|
||||
model = nn.DataParallel(model, device_ids=[])
|
||||
model = nn.DataParallel(model, device_ids=[device])
|
||||
# model.load_state_dict(torch.load(model_path), strict=False)
|
||||
state_dict = torch.load(model_path)['state_dict']
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
|
Loading…
Reference in New Issue
Block a user