demo-ml-pennfudanped

Форк
0
/
maskrcnn_model.py 
23 строки · 1.0 Кб
1
import torchvision
2
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
3
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
4

5

6
def build_maskrsnn_model(for_num_of_classes, weights):
7
    # load an instance segmentation model pre-trained on COCO
8
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=weights)
9

10
    # get number of input features for the classifier
11
    in_features = model.roi_heads.box_predictor.cls_score.in_features
12
    # replace the pre-trained head with a new one
13
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, for_num_of_classes)
14

15
    # now get the number of input features for the mask classifier
16
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
17
    hidden_layer = 256
18
    # and replace the mask predictor with a new one
19
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
20
                                                       hidden_layer,
21
                                                       for_num_of_classes)
22

23
    return model
24

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.