行人重识别代码实战(二)

代码描述

代码来源: https://github.com/layumi/Person_reID_baseline_pytorch

详细信息可见README.md

这次研究的是model.py原理是利用和修改预训练模型,代码原作者使用的是ImageNet预训练网络。

pytorch里引入方式如下:

1
2
from torchvision import models
model =models.resnet50(pretrained = True)

通过print(model)查看网络结构:

实际使用时要做修改。考虑到Market1501训练集中有751个不同的人,所以要改变模型来训练Reid的分类器:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# Define the ResNet50-based Model
class ft_net(nn.Module):

def __init__(self, class_num, droprate=0.5, stride=2):
super(ft_net, self).__init__()
model_ft = models.resnet50(pretrained=True)
# avg pooling to global pooling
if stride == 1:
model_ft.layer4[0].downsample[0].stride = (1,1)
model_ft.layer4[0].conv2.stride = (1,1)
model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1))
self.model = model_ft
self.classifier = ClassBlock(2048, class_num, droprate)

def forward(self, x):
x = self.model.conv1(x)
x = self.model.bn1(x)
x = self.model.relu(x)
x = self.model.maxpool(x)
x = self.model.layer1(x)
x = self.model.layer2(x)
x = self.model.layer3(x)
x = self.model.layer4(x)
x = self.model.avgpool(x)
x = x.view(x.size(0), x.size(1))
x = self.classifier(x)
return x

更多细节在model.py中,里面还包含了其他的预训练模型以及对应的修改方法。

文章作者: GeYu
文章链接: https://nuistgy.github.io/2019/08/24/行人重识别代码实战(二)/
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 Yu's Blog