lama
17 строк · 557.0 Байт
1import torch2import torch.nn as nn3
class DepthWiseSeperableConv(nn.Module):
    """Depthwise-separable 2D convolution: a per-channel conv followed by a 1x1 conv.

    The depthwise stage is ``nn.Conv2d(in_dim, in_dim, *args, groups=in_dim,
    **kwargs)``, so each input channel is convolved independently. The pointwise
    stage is a 1x1 ``nn.Conv2d(in_dim, out_dim)`` that mixes channels.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        *args: Extra positional arguments forwarded to the depthwise
            ``nn.Conv2d`` after ``in_channels``/``out_channels``
            (e.g. ``kernel_size``, ``stride``, ``padding``).
        **kwargs: Extra keyword arguments forwarded to the depthwise
            ``nn.Conv2d``. A ``groups`` keyword is discarded, because the
            depthwise stage always uses ``groups=in_dim``. If ``device`` and/or
            ``dtype`` are given, they are also applied to the pointwise stage,
            so both stages share the same device and dtype.

    Note:
        ``groups`` is stripped only when passed by keyword. Supplying it
        positionally through ``*args`` collides with ``groups=in_dim`` and
        raises ``TypeError``. Other keyword arguments, including ``bias``, are
        not forwarded to the pointwise stage. It always keeps ``nn.Conv2d``'s
        default bias, so the module's parameter set (and state_dict keys)
        does not depend on them.
    """

    def __init__(self, in_dim, out_dim, *args, **kwargs):
        super().__init__()
        # The depthwise stage must use groups=in_dim; drop any caller-supplied
        # value instead of passing ``groups`` twice.
        kwargs.pop('groups', None)

        self.depthwise = nn.Conv2d(in_dim, in_dim, *args, groups=in_dim, **kwargs)

        # Forward the factory kwargs so both stages are built on the same
        # device with the same dtype. Without this, e.g. device='cuda' would
        # leave the 1x1 conv on the CPU and forward() would fail.
        factory_kwargs = {
            key: kwargs[key] for key in ('device', 'dtype') if key in kwargs
        }
        self.pointwise = nn.Conv2d(in_dim, out_dim, kernel_size=1, **factory_kwargs)

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise 1x1 convolution.

        Args:
            x: Input tensor accepted by ``nn.Conv2d`` with ``in_dim`` channels
                (presumably shaped (N, in_dim, H, W); confirm against callers).

        Returns:
            The tensor produced by the pointwise conv, with ``out_dim`` channels.
        """
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out