行人重识别代码实战(三)

代码描述

代码来源: https://github.com/layumi/Person_reID_baseline_pytorch

详细信息可见README.md

准备好了训练数据和网络结构,下面就可以训练了:

1
2
3
4
5
6
7
python train.py --gpu_ids 0 --name ft_ResNet50 --train_all --batchsize 32  --data_dir your_data_path
--gpu_ids which gpu to run.
--name the name of the model.
--data_dir the path of the training data.
--train_all using all images to train.
--batchsize batch size.
--erasing_p random erasing probability.

这里探究一下train.py中都做了些什么。
首先是读取数据和label。这里使用了torch.utils.data.DataLoader, 可以获得两个迭代器dataloaders['train'] and dataloaders['val'] 来读数据:

1
2
3
4
5
6
7
8
9
10
image_datasets = {}
image_datasets['train'] = datasets.ImageFolder(os.path.join(data_dir, 'train'),
data_transforms['train'])
image_datasets['val'] = datasets.ImageFolder(os.path.join(data_dir, 'val'),
data_transforms['val'])

dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
shuffle=True, num_workers=8) # 8 workers may work faster
for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

以下则是主要的代码来训练模型,一共只有20行:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# Iterate over data.
for data in dataloaders[phase]:
# get a batch of inputs
inputs, labels = data
now_batch_size,c,h,w = inputs.shape
if now_batch_size<opt.batchsize: # skip the last batch
continue
# print(inputs.shape)
# wrap them in Variable, if gpu is used, we transform the data to cuda.
if use_gpu:
inputs = Variable(inputs.cuda())
labels = Variable(labels.cuda())
else:
inputs, labels = Variable(inputs), Variable(labels)

# zero the parameter gradients
optimizer.zero_grad()

#-------- forward --------
outputs = model(inputs)
_, preds = torch.max(outputs.data, 1)
loss = criterion(outputs, labels)

#-------- backward + optimize --------
# only if in training phase
if phase == 'train':
loss.backward()
optimizer.step()

每十轮,都会保存网络和更新loss曲线:

1
2
3
if epoch%10 == 9:
save_network(model, epoch)
draw_curve(epoch)

更多细节见train.py

文章作者: GeYu
文章链接: https://nuistgy.github.io/2019/08/24/行人重识别代码实战(三)/
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 Yu's Blog