庚午里藻の日記

見た映画とかアニメの備忘録にしたり、パソコンいじったことのメモにしたり

pytorchでU-Netやってみた話

論文を調べていたらU-Netにちょくちょく遭遇したので、pytorchでやってみたよっていう話です。

Qiitaにすごい綺麗にまとまっているものがある(しかもデータセット周りとか完全にそれを参考にしている)ので、ブログに書くか迷ったんですが、pytorchでちゃんとやった的な日本語の記事があんまり見つからなかったので、じゃあ書いちゃおっかなと思った次第です。

U-Netについて

セマンティックセグメンテーションを行うネットワークの1つです。

セマンティックセグメンテーションは、めっちゃ簡単にいうとピクセル単位のクラス分類問題で、出力として入力画像の各ピクセルが何に属しているかを表すカラーマップが得られます。

f:id:kanoeuma_310mo:20200314230643j:plain:w200f:id:kanoeuma_310mo:20200314230656p:plain:w200
こんな感じ

U-Netは以下のような構造をしています。

f:id:kanoeuma_310mo:20200314224246p:plain
[1]より引用

左側のcontracting pathではダウンサンプリングをして、右側のexpansive pathではアップサンプリングすることで最終的にセグメンテーションマップを得ます。その際にcontracting pathからexpansive pathに特徴量を結合することでコンテキスト情報を渡します。(他の人がまとめている説明とかを見ると、浅い畳み込み層の特徴量は空間情報を保持していることがポイントだ、という説明もあったんですが、論文をざっと読んだだけだと読み取れませんでした…)

U-NetはFCN(fully convolutional network)を改良したものです。論文によるとFCNからの"one important modification"はupsampling partの特徴量のチャネルの値が大きいため、コンテキスト情報を渡すことができることらしいです。(FCNのスタンダードな構造をきちんと調べていないのでいまいちピンと来ない。)

データセット

データセットはVOC2012というものを使いました。

U-NetでPascal VOC 2012の画像をSemantic Segmentationする (TensorFlow) - Qiita

この記事のデータセットの取り扱いをめちゃくちゃ参考にしたので、貼っておきます。

今回は面倒臭かったのでdata augmentationとかはしなかったです。(そのせいでテストデータに対する結果はいまいちでした。)

実装

U-Netの実装

U-Netの実装はgithubから拾ってきました。

github.com

これのBasic_blocks.pyとUNet.pyをそのままネットワークの実装として使いました。

実際のコードはこんな感じ。

def conv_block(in_dim, out_dim, act_fn):
    model = nn.Sequential(
        nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_dim),
        act_fn,
    )
    return model

def conv_trans_block(in_dim, out_dim, act_fn):
    model = nn.Sequential(
        nn.ConvTranspose2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
        nn.BatchNorm2d(out_dim),
        act_fn,
    )
    return model

def maxpool():
    pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
    return pool

def conv_block_2(in_dim, out_dim, act_fn):
    model = nn.Sequential(
        conv_block(in_dim, out_dim, act_fn),
        conv_block(out_dim, out_dim, act_fn),
        nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_dim),
    )
    return model

class Unet(nn.Module):
    def __init__(self, in_dim, out_dim, num_filter):
        super(Unet, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_filter = num_filter
        act_fn = nn.LeakyReLU(0.2, inplace=True)

        #print('Â¥n-----Initiating U-Net-----Â¥n')

        self.down_1 = conv_block_2(self.in_dim, self.num_filter, act_fn)
        self.pool_1 = maxpool()
        self.down_2 = conv_block_2(self.num_filter*1, self.num_filter*2, act_fn)
        self.pool_2 = maxpool()
        self.down_3 = conv_block_2(self.num_filter*2, self.num_filter*4, act_fn)
        self.pool_3 = maxpool()
        self.down_4 = conv_block_2(self.num_filter*4, self.num_filter*8, act_fn)
        self.pool_4 = maxpool()

        self.bridge = conv_block_2(self.num_filter*8, self.num_filter*16, act_fn)

        self.trans_1 = conv_trans_block(self.num_filter*16, self.num_filter*8, act_fn)
        self.up_1 = conv_block_2(self.num_filter*16, self.num_filter*8, act_fn)
        self.trans_2 = conv_trans_block(self.num_filter*8, self.num_filter*4, act_fn)
        self.up_2 = conv_block_2(self.num_filter*8, self.num_filter*4, act_fn)
        self.trans_3 = conv_trans_block(self.num_filter*4, self.num_filter*2, act_fn)
        self.up_3 = conv_block_2(self.num_filter*4, self.num_filter*2, act_fn)
        self.trans_4 = conv_trans_block(self.num_filter*2, self.num_filter*1, act_fn)
        self.up_4 = conv_block_2(self.num_filter*2, self.num_filter*1, act_fn)

        self.out = nn.Sequential(
            nn.Conv2d(self.num_filter, self.out_dim, 3, 1, 1),
            nn.LogSoftmax(dim=1),
        )

    def forward(self,input):
        down_1 = self.down_1(input)
        pool_1 = self.pool_1(down_1)
        down_2 = self.down_2(pool_1)
        pool_2 = self.pool_2(down_2)
        down_3 = self.down_3(pool_2)
        pool_3 = self.pool_3(down_3)
        down_4 = self.down_4(pool_3)
        pool_4 = self.pool_4(down_4)
    
        bridge = self.bridge(pool_4)
    
        trans_1 = self.trans_1(bridge)
        concat_1 = torch.cat([trans_1,down_4],dim=1)
        up_1 = self.up_1(concat_1)
        trans_2 = self.trans_2(up_1)
        concat_2 = torch.cat([trans_2,down_3],dim=1)
        up_2 = self.up_2(concat_2)
        trans_3 = self.trans_3(up_2)
        concat_3 = torch.cat([trans_3,down_2],dim=1)
        up_3 = self.up_3(concat_3)
        trans_4 = self.trans_4(up_3)
        concat_4 = torch.cat([trans_4,down_1],dim=1)
        up_4 = self.up_4(concat_4)
    
        out = self.out(up_4)
    
        return out

論文の通りの実装ではないみたいで、パディングとかして入出力を同じサイズにしてる感じみたいです。

データセットの実装

組み込みではないデータセットを使うので、データセットも書いておきます。

f:id:kanoeuma_310mo:20200315001706p:plain:w300

データセットディレクトリはこんな感じにしました。datasetディレクトリ以下にtrainとvalディレクトリがあり、それぞれに元画像のディレクトリ(original)とセグメント画像のディレクトリ(segment)がある感じです。

データセットのコードはこんな感じです。

class SegmentDataset(torch.utils.data.Dataset):
    def __init__(self, transform=None, train=True):
        self.transform = transform
        self.original_images = []
        self.segmented_images = []

        root = 'dataset'

        if train==True:
            original_image_path = os.path.join(root, 'train', 'original')
            segmented_image_path = os.path.join(root, 'train', 'segment')
        else:
            original_image_path = os.path.join(root, 'val', 'original')
            segmented_image_path = os.path.join(root, 'val', 'segment')

        original_image_files = os.listdir(original_image_path)
        segmented_image_files = os.listdir(segmented_image_path)

        for original_image_file in original_image_files:
            self.original_images.append(os.path.join(original_image_path, original_image_file))
            self.segmented_images.append(os.path.join(segmented_image_path, original_image_file.split('.')[0]+'.png'))


    def __getitem__(self, index):
        original = self.original_images[index]
        segmented = self.segmented_images[index]

        with open(original, 'rb') as f:
            image = Image.open(f)
            image = image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        with open(segmented, 'rb') as f:
            label = Image.open(f)
            label.convert('P')
            label = np.asarray(label)
            label = np.where(label == 255, 22-1, label)
            """
            segment_numpy = np.zeros([22, label.shape[0], label.shape[1]])
            for c in range(22):
                segment_numpy[c] = np.where(label==c, 1, 0)
            """
            #segment_numpy = torch.from_numpy(segment_numpy)

        return image, label

    def __len__(self):
        return len(self.original_images)

実際に学習する

実際に学習します。100epoch回すコードを書きます。こんな感じです。

data_transforms = transforms.Compose([
    transforms.ToTensor()
])

train_dataset = SegmentDataset(data_transforms, train=True)
test_dataset = SegmentDataset(data_transforms, train=False)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=1, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

device = 'cuda' if torch.cuda.is_available else 'cpu'
print(device)
net = Unet(3, 22, 64).to(device)

criterion = nn.NLLLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

num_epochs = 100

train_loss_list = []
train_acc_list = []
val_loss_list = []
val_acc_list = []

for epoch in range(num_epochs):
    train_loss = 0
    val_loss = 0
    
    #train
    net.train()
    with tqdm(train_loader, ncols=100) as pbar:
        for i, (images, labels) in enumerate(pbar):
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels.long())
            train_loss += loss.item()
            loss.backward()
            optimizer.step()
    
    avg_train_loss = train_loss / len(train_loader.dataset)
    
    #val
    net.eval()
    with torch.no_grad():
        with tqdm(test_loader, ncols=100) as pbar:
            for images, labels in pbar:
                images = images.to(device)
                labels = labels.to(device)
                outputs = net(images)
                loss = criterion(outputs, labels.long())
                val_loss += loss.item()
    avg_val_loss = val_loss / len(test_loader.dataset)
    
    print ('Epoch [{}/{}], Loss: {loss:.4f}, val_loss: {val_loss:.4f}' 
                   .format(epoch+1, num_epochs, loss=avg_train_loss, val_loss=avg_val_loss))
    train_loss_list.append(avg_train_loss)
    val_loss_list.append(avg_val_loss)

    if epoch % 10 ==0:
        torch.save(net.state_dict(), 'param/param-'+str(epoch)+'.pth')
    train_loss_list.append(avg_train_loss)
    val_loss_list.append(avg_val_loss)

with open('train_loss.pickle', 'wb') as f:
    pickle.dump(train_loss_list, f)
with open('val_loss.pickle', 'wb') as f:
    pickle.dump(val_loss_list, f)

損失関数はnn.NLLLoss()を使っています。ネットワークの実装の部分で、出力の最後にnn.LogSoftmax(dim=1)を付け足しているんですが、nn.LogSoftmax(dim=1)→nn.NLLLoss()とすることで2次元のデータに対してクロスエントロピー誤差を計算できるようになります。

そこら辺どうせ面倒くさいんだろうなあと思って最初はnn.MSELoss()を使っていたんですがうまくいかなかったので調べたら意外と簡単でした。ちなみにググってたらnn.NLLLoss2d()を使え!みたいな記事が出てきて、試したんですが「それもう使えないよ」ってpytorchに言われてほーんみたいな気持ちになりました。

結果

lossの落ち方を見てみます。

f:id:kanoeuma_310mo:20200315013353p:plain:w600

まだ学習のlossは落ちそうな感じです。ただ、テストのlossはバリバリに上がってしまっていたので、汎化はできてないのかなあって感じでした。data augmentationをしてないので、データ量が足りてなかったりもするんじゃないかと思います。

そこまでうまくいってないですが、いくつか実際の結果も見てみます。

まずは学習データに対する結果です。左から元画像、正解画像、推定画像です。

f:id:kanoeuma_310mo:20200315013817j:plain:w200f:id:kanoeuma_310mo:20200315013825p:plain:w200f:id:kanoeuma_310mo:20200315013833p:plain:w200

f:id:kanoeuma_310mo:20200315014155j:plain:w200f:id:kanoeuma_310mo:20200315014152p:plain:w200f:id:kanoeuma_310mo:20200315014159p:plain:w200

f:id:kanoeuma_310mo:20200315014324j:plain:w200f:id:kanoeuma_310mo:20200315014321p:plain:w200f:id:kanoeuma_310mo:20200315014326p:plain:w200

簡単なセグメンテーションはうまくいってますが、複雑なものは結構ひどい結果になっています。

次はテストデータに対する結果です。左から元画像、正解画像、推定画像です。

f:id:kanoeuma_310mo:20200315022653j:plain:w200f:id:kanoeuma_310mo:20200315022656p:plain:w200f:id:kanoeuma_310mo:20200315022650p:plain:w200

f:id:kanoeuma_310mo:20200315022816j:plain:w200f:id:kanoeuma_310mo:20200315022819p:plain:w200f:id:kanoeuma_310mo:20200315022812p:plain:w200

f:id:kanoeuma_310mo:20200315022901j:plain:w200f:id:kanoeuma_310mo:20200315022903p:plain:w200f:id:kanoeuma_310mo:20200315022858p:plain:w200

恣意的にうまくいってるものを多めに入れましたがこんなもんでした。

まとめ

U-Netをpytorchでやってみました。内容的には真新しさは微塵もありませんが、既存のデータセットをきちんと扱ったり、損失関数の設計をちゃんと調べたりと勉強になったのでよかったかなあと思っています。

参考文献

[1] Olaf Ronneberger, Philipp Fischer, Thomas Brox, U-Net: Convolutional Networks for Biomedical Image Segmentation, CoRR, 2015.