demo-ml-pennfudanped
/
maskrcnn_model.py
23 строки · 1.0 Кб
1import torchvision2from torchvision.models.detection.faster_rcnn import FastRCNNPredictor3from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor4
5
def build_maskrsnn_model(for_num_of_classes, weights, *, hidden_layer=256):
    """Build a torchvision Mask R-CNN (ResNet-50 FPN) with resized heads.

    The model is created through
    ``torchvision.models.detection.maskrcnn_resnet50_fpn`` and then both the
    box predictor and the mask predictor in ``model.roi_heads`` are replaced
    by new predictors that output ``for_num_of_classes`` classes.

    Args:
        for_num_of_classes: Number of classes the new box and mask predictors
            should output. NOTE(review): torchvision detection heads
            presumably count the background as one of these classes --
            confirm against the caller.
        weights: Passed through unchanged as the ``weights`` argument of
            ``maskrcnn_resnet50_fpn``; it decides which (if any) pre-trained
            weights the model starts from.
        hidden_layer: Size of the hidden layer of the new mask predictor.
            Keyword-only; defaults to 256, the value that used to be
            hard-coded here.

    Returns:
        The Mask R-CNN model with both predictors replaced.
    """
    # Start from the stock instance segmentation model.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=weights)
    roi_heads = model.roi_heads

    # Box head: reuse the input width of the existing classifier so the new
    # predictor plugs into the same feature extractor.
    in_features = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(in_features, for_num_of_classes)

    # Mask head: same idea, the new predictor must accept the channel count
    # that the existing first mask layer consumes.
    in_features_mask = roi_heads.mask_predictor.conv5_mask.in_channels
    roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask,
        hidden_layer,
        for_num_of_classes,
    )

    return model