pytorchでU-Netやってみた話
論文を調べていたらU-Netにちょくちょく遭遇したので、pytorchでやってみたよっていう話です。
Qiitaにすごい綺麗にまとまっているものがある(しかもデータセット周りとか完全にそれを参考にしている)ので、ブログに書くか迷ったんですが、pytorchでちゃんとやった的な日本語の記事があんまり見つからなかったので、じゃあ書いちゃおっかなと思った次第です。
U-Netについて
セマンティックセグメンテーションを行うネットワークの1つです。
セマンティックセグメンテーションは、めっちゃ簡単にいうとピクセル単位のクラス分類問題で、出力として入力画像の各ピクセルが何に属しているかを表すカラーマップが得られます。
U-Netは以下のような構造をしています。
左側のcontracting pathではダウンサンプリングをして、右側のexpansive pathではアップサンプリングすることで最終的にセグメンテーションマップを得ます。その際にcontracting pathからexpansive pathに特徴量を結合することでコンテキスト情報を渡します。(他の人がまとめている説明とかを見ると、浅い畳み込み層の特徴量は空間情報を保持していることがポイントだ、という説明もあったんですが、論文をざっと読んだだけだと読み取れませんでした…)
U-NetはFCN(fully convolutional network)を改良したものです。論文によるとFCNからの"one important modification"はupsampling partの特徴量のチャネルの値が大きいため、コンテキスト情報を渡すことができることらしいです。(FCNのスタンダードな構造をきちんと調べていないのでいまいちピンと来ない。)
データセット
データセットはVOC2012というものを使いました。
U-NetでPascal VOC 2012の画像をSemantic Segmentationする (TensorFlow) - Qiita
この記事のデータセットの取り扱いをめちゃくちゃ参考にしたので、貼っておきます。
今回は面倒臭かったのでdata augmentationとかはしなかったです。(そのせいでテストデータに対する結果はいまいちでした。)
実装
U-Netの実装
U-Netの実装はgithubから拾ってきました。
これのBasic_blocks.pyとUNet.pyをそのままネットワークの実装として使いました。
実際のコードはこんな感じ。
def conv_block(in_dim, out_dim, act_fn): model = nn.Sequential( nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(out_dim), act_fn, ) return model def conv_trans_block(in_dim, out_dim, act_fn): model = nn.Sequential( nn.ConvTranspose2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1), nn.BatchNorm2d(out_dim), act_fn, ) return model def maxpool(): pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) return pool def conv_block_2(in_dim, out_dim, act_fn): model = nn.Sequential( conv_block(in_dim, out_dim, act_fn), conv_block(out_dim, out_dim, act_fn), nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(out_dim), ) return model class Unet(nn.Module): def __init__(self, in_dim, out_dim, num_filter): super(Unet, self).__init__() self.in_dim = in_dim self.out_dim = out_dim self.num_filter = num_filter act_fn = nn.LeakyReLU(0.2, inplace=True) #print('Â¥n-----Initiating U-Net-----Â¥n') self.down_1 = conv_block_2(self.in_dim, self.num_filter, act_fn) self.pool_1 = maxpool() self.down_2 = conv_block_2(self.num_filter*1, self.num_filter*2, act_fn) self.pool_2 = maxpool() self.down_3 = conv_block_2(self.num_filter*2, self.num_filter*4, act_fn) self.pool_3 = maxpool() self.down_4 = conv_block_2(self.num_filter*4, self.num_filter*8, act_fn) self.pool_4 = maxpool() self.bridge = conv_block_2(self.num_filter*8, self.num_filter*16, act_fn) self.trans_1 = conv_trans_block(self.num_filter*16, self.num_filter*8, act_fn) self.up_1 = conv_block_2(self.num_filter*16, self.num_filter*8, act_fn) self.trans_2 = conv_trans_block(self.num_filter*8, self.num_filter*4, act_fn) self.up_2 = conv_block_2(self.num_filter*8, self.num_filter*4, act_fn) self.trans_3 = conv_trans_block(self.num_filter*4, self.num_filter*2, act_fn) self.up_3 = conv_block_2(self.num_filter*4, self.num_filter*2, act_fn) self.trans_4 = conv_trans_block(self.num_filter*2, self.num_filter*1, act_fn) self.up_4 = conv_block_2(self.num_filter*2, self.num_filter*1, act_fn) self.out = nn.Sequential( nn.Conv2d(self.num_filter, self.out_dim, 3, 1, 1), nn.LogSoftmax(dim=1), ) def forward(self,input): down_1 = self.down_1(input) pool_1 = self.pool_1(down_1) down_2 = self.down_2(pool_1) pool_2 = self.pool_2(down_2) down_3 = self.down_3(pool_2) pool_3 = self.pool_3(down_3) down_4 = self.down_4(pool_3) pool_4 = self.pool_4(down_4) bridge = self.bridge(pool_4) trans_1 = self.trans_1(bridge) concat_1 = torch.cat([trans_1,down_4],dim=1) up_1 = self.up_1(concat_1) trans_2 = self.trans_2(up_1) concat_2 = torch.cat([trans_2,down_3],dim=1) up_2 = self.up_2(concat_2) trans_3 = self.trans_3(up_2) concat_3 = torch.cat([trans_3,down_2],dim=1) up_3 = self.up_3(concat_3) trans_4 = self.trans_4(up_3) concat_4 = torch.cat([trans_4,down_1],dim=1) up_4 = self.up_4(concat_4) out = self.out(up_4) return out
論文の通りの実装ではないみたいで、パディングとかして入出力を同じサイズにしてる感じみたいです。
データセットの実装
組み込みではないデータセットを使うので、データセットも書いておきます。
データセットのディレクトリはこんな感じにしました。datasetディレクトリ以下にtrainとvalディレクトリがあり、それぞれに元画像のディレクトリ(original)とセグメント画像のディレクトリ(segment)がある感じです。
データセットのコードはこんな感じです。
class SegmentDataset(torch.utils.data.Dataset): def __init__(self, transform=None, train=True): self.transform = transform self.original_images = [] self.segmented_images = [] root = 'dataset' if train==True: original_image_path = os.path.join(root, 'train', 'original') segmented_image_path = os.path.join(root, 'train', 'segment') else: original_image_path = os.path.join(root, 'val', 'original') segmented_image_path = os.path.join(root, 'val', 'segment') original_image_files = os.listdir(original_image_path) segmented_image_files = os.listdir(segmented_image_path) for original_image_file in original_image_files: self.original_images.append(os.path.join(original_image_path, original_image_file)) self.segmented_images.append(os.path.join(segmented_image_path, original_image_file.split('.')[0]+'.png')) def __getitem__(self, index): original = self.original_images[index] segmented = self.segmented_images[index] with open(original, 'rb') as f: image = Image.open(f) image = image.convert('RGB') if self.transform is not None: image = self.transform(image) with open(segmented, 'rb') as f: label = Image.open(f) label.convert('P') label = np.asarray(label) label = np.where(label == 255, 22-1, label) """ segment_numpy = np.zeros([22, label.shape[0], label.shape[1]]) for c in range(22): segment_numpy[c] = np.where(label==c, 1, 0) """ #segment_numpy = torch.from_numpy(segment_numpy) return image, label def __len__(self): return len(self.original_images)
実際に学習する
実際に学習します。100epoch回すコードを書きます。こんな感じです。
data_transforms = transforms.Compose([ transforms.ToTensor() ]) train_dataset = SegmentDataset(data_transforms, train=True) test_dataset = SegmentDataset(data_transforms, train=False) train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=1, shuffle=True) test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False) device = 'cuda' if torch.cuda.is_available else 'cpu' print(device) net = Unet(3, 22, 64).to(device) criterion = nn.NLLLoss() optimizer = optim.Adam(net.parameters(), lr=0.001) num_epochs = 100 train_loss_list = [] train_acc_list = [] val_loss_list = [] val_acc_list = [] for epoch in range(num_epochs): train_loss = 0 val_loss = 0 #train net.train() with tqdm(train_loader, ncols=100) as pbar: for i, (images, labels) in enumerate(pbar): images, labels = images.to(device), labels.to(device) optimizer.zero_grad() outputs = net(images) loss = criterion(outputs, labels.long()) train_loss += loss.item() loss.backward() optimizer.step() avg_train_loss = train_loss / len(train_loader.dataset) #val net.eval() with torch.no_grad(): with tqdm(test_loader, ncols=100) as pbar: for images, labels in pbar: images = images.to(device) labels = labels.to(device) outputs = net(images) loss = criterion(outputs, labels.long()) val_loss += loss.item() avg_val_loss = val_loss / len(test_loader.dataset) print ('Epoch [{}/{}], Loss: {loss:.4f}, val_loss: {val_loss:.4f}' .format(epoch+1, num_epochs, loss=avg_train_loss, val_loss=avg_val_loss)) train_loss_list.append(avg_train_loss) val_loss_list.append(avg_val_loss) if epoch % 10 ==0: torch.save(net.state_dict(), 'param/param-'+str(epoch)+'.pth') train_loss_list.append(avg_train_loss) val_loss_list.append(avg_val_loss) with open('train_loss.pickle', 'wb') as f: pickle.dump(train_loss_list, f) with open('val_loss.pickle', 'wb') as f: pickle.dump(val_loss_list, f)
損失関数はnn.NLLLoss()
を使っています。ネットワークの実装の部分で、出力の最後にnn.LogSoftmax(dim=1)
を付け足しているんですが、nn.LogSoftmax(dim=1)→nn.NLLLoss()
とすることで2次元のデータに対してクロスエントロピー誤差を計算できるようになります。
そこら辺どうせ面倒くさいんだろうなあと思って最初はnn.MSELoss()
を使っていたんですがうまくいかなかったので調べたら意外と簡単でした。ちなみにググってたらnn.NLLLoss2d()
を使え!みたいな記事が出てきて、試したんですが「それもう使えないよ」ってpytorchに言われてほーんみたいな気持ちになりました。
結果
lossの落ち方を見てみます。
まだ学習のlossは落ちそうな感じです。ただ、テストのlossはバリバリに上がってしまっていたので、汎化はできてないのかなあって感じでした。data augmentationをしてないので、データ量が足りてなかったりもするんじゃないかと思います。
そこまでうまくいってないですが、いくつか実際の結果も見てみます。
まずは学習データに対する結果です。左から元画像、正解画像、推定画像です。
簡単なセグメンテーションはうまくいってますが、複雑なものは結構ひどい結果になっています。
次はテストデータに対する結果です。左から元画像、正解画像、推定画像です。
恣意的にうまくいってるものを多めに入れましたがこんなもんでした。
まとめ
U-Netをpytorchでやってみました。内容的には真新しさは微塵もありませんが、既存のデータセットをきちんと扱ったり、損失関数の設計をちゃんと調べたりと勉強になったのでよかったかなあと思っています。
参考文献
[1] Olaf Ronneberger, Philipp Fischer, Thomas Brox, U-Net: Convolutional Networks for Biomedical Image Segmentation, CoRR, 2015.