본문 바로가기
  • Deep dive into Learning
  • Deep dive into Optimization
  • Deep dive into Deep Learning
Deep dive into Pytorch

Pytorch 10 : Transfer learning

by Sapiens_Nam 2023. 8. 10.

 

오늘은 image classification을 'transfer learning'을 활용해 파이토치로 구현하는 코드에 대해 살펴보고자 한다.

오늘날 CNN 모델을 random initialization을 하여 'scratch'로 학습하는 것은 상당히 드문 일이다. 대다수의 경우에는 

이미 학습되어진 (pre-trained) 모델을 가져다가 파라미터를 살짝 조정하는 방식으로의 미세 조정 (Fine-tuning) 을 진행하거나, CNN에서 특징을 추출하는 부분인 Convolution layer 부분은 파라미터를 고정시켜놓고 (Freeze), 마지막에 분류를 진행하는 fully-connected layer 부분만 바꾸어서 재학습하는 방식등을 사용한다.

 

이때 사전 학습된 CNN은 주로 ImageNet 같은 상당히 큰 데이터셋으로 학습이 되어졌고, 이때의 파라미터가 초기화된 파라미터값이 되는 것이다. 그리고 우리는 우리의 자체 데이터를 가지고 이 파라미터에서 약간의 미세 조정을 통해 학습된 모델을 우리가 원하는 Task에 사용하는 것이다.

이는 처음에 random하게 initialization을 하여서 학습을 처음부터 진행하는 것보다 상당히 효율적이고 좋은 성능을 보여주는 방법이다.

 

 

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_schedular
import torchvision
from torchvision import datasets, models, transforms

 

 

우리가 가지고 있는 데이터 (custom dataset)는 두 개의 클래스가 있다고 가정하자.

그리고 ImageNet으로 사전 학습된 ResNet 18을 사용한다고 할 때, 이 모델의 fully connected layer 부분을 수정해야 한다.

 

model_ft = models.resnet18(weights = 'IMAGENET1K_V1')
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)
model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(model_ft.parameters(), lr = 0.001, momentum = 0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)

 

ImageNet은 1000개의 클래스를 가진 large dataset이고 이 데이터셋으로 사전 학습된 resnet18의 output layer는 1,000개의 노드로 이뤄져 있다. 우리가 가진 데이터셋은 2개의 클래스라고 가정하였으므로, 우리는 이 부분에 layer를 하나 추가하여 2개의 노드로 이뤄진 output layer로 바꾸어주었다.

 

이후 Fine-tuning 학습을 위한 SGD + Momentum optimizer와 learning rate scheduler를 선언해주었다.

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    # Create a temporary directory to save training checkpoints
    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')

        torch.save(model.state_dict(), best_model_params_path)
        best_acc = 0.0

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)

            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                if phase == 'train':
                    model.train()  # Set model to training mode
                else:
                    model.eval()   # Set model to evaluate mode

                running_loss = 0.0
                running_corrects = 0

                # Iterate over data.
                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # zero the parameter gradients
                    optimizer.zero_grad()

                    # forward
                    # track history if only in train
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                        # backward + optimize only if in training phase
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    # statistics
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data)
                if phase == 'train':
                    scheduler.step()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects.double() / dataset_sizes[phase]

                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                # deep copy the model
                if phase == 'val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), best_model_params_path)

            print()

        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best val Acc: {best_acc:4f}')

        # load best model weights
        model.load_state_dict(torch.load(best_model_params_path))
    return model
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs = 25)

 

이렇게 하면 우리는 ImageNet으로 사전 학습된 모델의 가중치를 시작 지점으로 하여 우리의 데이터셋으로 모델을 조금만 학습시켜도 충분히 좋은 결과를 얻어낼 수 있게 된다.

 

다음은 Convolution layer 부분의 파라미터는 고정시킨 채로 마지막 layer 파라미터만 학습하는 방식의 코드를 보도록 하자.

여기서 핵심은 'requires_grad = False'로 설정하여서 파라미터에 대한 gradient를 계산하지 않도록 하는 것이다.

 

model_conv = torchvision.models.resnet18(weights='IMAGENET1K_V1')
for param in model_conv.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

 

위의 for 문을 살펴보면 모델의 convolution layer 부분의 파라미터들을 requires_grad = False로 선언해주는데, 이는 loss.backward()를 통해서 graident가 계산되지 않도록 하기 위함이다.

이를 우리는 파라미터를 'freezing'한다고 이야기한다.

이렇게 한 후, 모델을 Fine-tuninig하면 convolution layer 부분은 파라미터가 고정되어져 있고, fully connected layer 부분의 파라미터만 학습이 이뤄진다.

 

728x90

'Deep dive into Pytorch' 카테고리의 다른 글

Pytorch 9 : Distributed training  (0) 2023.08.05
Pytorch 8 : Train과 Test  (0) 2023.07.26
Pytorch 7 : Tensor 심화  (0) 2023.07.23
Pytorch 6 : Implement CNN  (0) 2023.07.20
Pytorch 5 : Save and Load  (0) 2023.07.17

댓글