paper
Abstract
- 기존 CNN은 한정된 자원에서 개발하고, 자원이 충분해지면 성능 개선을 위한 scale up을 진행했다.
- 본 연구에서는 model scaling에 대한 연구를 진행했고, depth/width/resolution이 균형을 이루는 network가 더 나은 성능을 보이는것을 확인했다.
- 이 연구에 기반해서, 'Compound Scaling Method'라는 새로운 scaling 방법을 제안한다.
- MobileNets, ResNet 모델에서 해당 방법을 통한 성능 향상을 입증했다.
Introduction
- 이전 연구에서는, depth/width/resolution 중 하나만 scale하는 것이 주를 이루었다.
- 이 연구에서는 depth/width/resolution 세 가지를 균등하게 scale 한다.
Compound Model Scaling
- depth : 네트워크가 깊어질수록 복잡한 feature를 잘 잡아내지만, vanishing gradient의 문제로 학습시키기 어려워진다.
- width : layer의 width가 커질수록 정확도가 높아지지만, 계산량이 증가한다.
- resolution : fine-grained(세부적인) feature를 잡아내기 쉽지만, accuracy gain(정확도 향상)이 더뎌진다.
- width/depth/resolution을 직접 tuning하던 것을, compound coefficient를 사용해 규칙에 의거해 scale한다.
- α, β, γ 는 작은 grid search로 결정할 수 있다.
- width(총 2번 연산)와 resolution(너비와 높이)은 flops가 4배가 늘어난다.
EfficientNet Architecture
- scaling은 baseline network의 layer operation을 변경하지 않기때문에, 좋은 baseline 선택이 중요하다.ㅌ
- 특히 큰 모델의 경우, α,β,γ를 직접 찾아 적용하면 성능이 좋아지지만 비용이 증가한다.
-
- 로 고정한 뒤, resource가 두 배로 있다고 가정하고 를 찾음
- 이번엔 를 고정하고 값을 바꿔서 baseline network를 scale up 함
Experiments
- MobileNet, ResNet을 scale up한 결과, 성능 향상을 보였다.
- ImageNet에서 적은 parameter로 더 높은 accuracy를 보였다.
Disscussion
- compound scaling한 모델이 relative region을 더 잘 집중한다.
Conclusion
- model scale up에서 width/depth/resolution의 균형이 중요하다.
- 성능향상을 위한 효과적인 compound scaling을 제안한다.
- 전이학습에서도 좋은 성능을 보인다.
Code
- Swish 활성화함수
class Swish(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
- SE Block (squeeze-and-excitation)
class SEBlock(nn.Module):
def __init__(self, in_channels, reduction=4):
super(SEBlock, self).__init__()
self.fc1 = nn.Conv2d(in_channels, in_channels//reduction, kernel_size=1)
self.fc2 = nn.Conv2d(in_channels//reduction, in_channels, kernel_size=1)
def forward(self, x):
squeeze = F.adaptive_avg_pool2d(x, 1)
excitation = torch.sigmoid(self.fc2(F.relu(self.fc1(squeeze))))
return x * excitation
- MBConv Block
class MBConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, expand_ratio, stride, se_ratio):
super(MBConvBlock, self).__init__()
self.stride = stride
hidden_dim = in_channels * expand_ratio
self.expand = nn.Conv2d(in_channels, hidden_dim, kernel_size=1, bias=False)
self.bn0 = nn.BatchNorm2d(hidden_dim)
self.depthwise = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=self.stride, padding=1, groups=hidden_dim, bias=False)
self.bn1 = nn.BatchNorm2d(hidden_dim)
self.se = SEBlock(hidden_dim, reduction=int(1/se_ratio))
self.project = nn.Conv2d(hidden_dim, out_channels, kernel_size=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.swish = Swish()
def forward(self, x):
residual = x
x = self.swish(self.bn0(self.expand(x)))
x = self.swish(self.bn1(self.depthwise(x)))
x = self.se(x)
x = self.bn2(self.project(x))
if self.stride == 1 and residual.size() == x.size():
x += residual
return x
- EfficientNet Class
class EfficientNet(nn.Module):
def __init__(self, num_classes=1000):
super(EfficientNet, self).__init__()
# initial
self.conv_stem = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
self.bn0 = nn.BatchNorm2d(32)
self.swish = Swish()
# MBConv blocks
self.blocks = nn.Sequential(
MBConvBlock(32, 16, expand_ratio=1, stride=1, se_ratio=0.25),
MBConvBlock(16, 24, expand_ratio=6, stride=2, se_ratio=0.25),
MBConvBlock(24, 40, expand_ratio=6, stride=2, se_ratio=0.25),
MBConvBlock(40, 80, expand_ratio=6, stride=2, se_ratio=0.25),
MBConvBlock(80, 112, expand_ratio=6, stride=1, se_ratio=0.25),
MBConvBlock(112, 192, expand_ratio=6, stride=2, se_ratio=0.25),
MBConvBlock(192, 320, expand_ratio=6, stride=1, se_ratio=0.25),
)
# last layers
self.conv_head = nn.Conv2d(320, 1280, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(1280)
self.fc = nn.Linear(1280, num_classes)
def forward(self, x):
x = self.swish(self.bn0(self.conv_stem(x))) # 모델의 줄기 (시작)
x = self.blocks(x)
x = self.swish(self.bn1(self.conv_head(x))) # 모델의 끝 부분
x = F.adaptive_avg_pool2d(x, 1)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
'ML_DL > 딥러닝 공부하기' 카테고리의 다른 글
Ollama 설치 및 Llama3.1 모델 사용 (0) | 2024.08.14 |
---|---|
텍스트 임베딩 해보기 (0) | 2024.08.10 |
validation set의 data augmentation (0) | 2024.05.04 |
머신러닝 VS 딥러닝 (0) | 2024.04.17 |
Word Embedding (0) | 2024.01.08 |