• 周五. 4月 26th, 2024

5G编程聚合网

5G时代下一个聚合的编程学习网

热门标签

DeeplabV3+训练自己的数据集(三)

admin

11月 28, 2021

模型训练及测试

一、在DeepLabv3+模型的基础上,主要需要修改以下两个文件

   data_generator.py

   train_utils.py

   (1)添加数据集描述

   在datasets/data_generator.py文件中,添加自己的数据集描述:
_CAMVID_INFORMATION = DatasetDescriptor(
    splits_to_sizes={
    'train': 1035,
    'val': 31,},
    num_classes=3,
    ignore_label=255, )
自己的数据集共有3个classes,算上了background。由于没有使用 ignore_label , 没有算上ignore_label

  (2)注册数据集

_DATASETS_INFORMATION = {
    'cityscapes': _CITYSCAPES_INFORMATION,
    'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
    'ade20k': _ADE20K_INFORMATION,
    'camvid':_CAMVID_INFORMATION,
    # 'mydata':_MYDATA_INFORMATION,
    }

  (3)修改train_utils.py 

  对应的utils/train_utils.py中,将210行关于 exclude_list 的设置修改,作用是在使用预训练权重时候,不加载该 logit 层:

  

exclude_list = ['global_step','logits']
if not initialize_last_layer:
    exclude_list.extend(last_layers)

  如果想在DeepLab的基础上fifine-tune其他数据集, 可在deeplab/train.py中修改输入参数。

  一些选项:
    使用预训练的所有权重,设置initialize_last_layer=True
    只使用网络的backbone,设置initialize_last_layer=False和
    last_layers_contain_logits_only=False
    使用所有的预训练权重,除了logits。因为如果是自己的数据集,对应的classes不同(这个我们前面已经设置不加载logits),可设置initialize_last_layer=False和ast_layers_contain_logits_only=True
  这里使用的设置是:
  initialize_last_layer=False #157行
  last_layers_contain_logits_only=True #160行

二、网路训练

  (1)下载预训练模型

  下载地址:https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md  

  下载到deeplab目录下,然后解压:
  tar -zxvf deeplabv3_cityscapes_train_2018_02_06.tar.gz
  需要注意对应的解压文件目录为:
/lwh/models/research/deeplab/deeplabv3_cityscapes_train

  (2)类别不平衡修正

    blackboard分割项目案例中的数据集,因为是3分类问题,其中background占了非常大的比例,设置的
    权重比例为1,3,3,
    注意:权重的设置对最终的分割性能有影响。权重的设置因数据集而异。    
    在common.py的145行修改权重如下:
  

flags.DEFINE_multi_float(
    'label_weights', [1.0,3.0,3.0],
    'A list of label weights, each element represents the weight for the label '
    'of its index, for example, label_weights = [0.1, 0.5] means the weight '
    'for label 0 is 0.1 and the weight for label 1 is 0.5. If set as None, all '
    'the labels have the same weight 1.0.')

  (3)训练

    注意如下几个参数:
    train_logdir: 训练产生的文件存放位置
    dataset_dir: 数据集的TFRecord文件
    dataset:设置为在data_generator.py文件设置的数据集名称
    
    在自己的数据集上的训练指令如下:
    在目录 ~/models/research/deeplab下执行
  

python train.py   --training_number_of_steps=30000  --train_split="train"  --model_variant="xception_65" 
--atrous_rates=6 --atrous_rates=12 --atrous_rates=18 --output_stride=16 --decoder_output_stride=4
--train_crop_size=801,801 --train_batch_size=2 --dataset="camvid"
--tf_initial_checkpoint='/lwh/models/research/deeplab/deeplabv3_cityscapes_train/model.ckpt'
--train_logdir='/lwh/models/research/deeplab/exp/blackboard_train/train'
--dataset_dir='/lwh/models/research/deeplab/datasets/blackboard/tfrecord'

    设置train_crop_size原则:

    output_stride * k + 1, where k is an integer. For example, we have 321×321,513×513,801×801

  (4)模型导出

  

python export_model.py 
    --logtostderr 
    --checkpoint_path="/lwh/models/research/deeplab/exp/blackboard_train/train/model.ckpt-30000" 
    --export_path="/lwh/models/research/deeplab/exp/blackboard_train/train/frozen_inference_graph.pb"  
    --model_variant="xception_65"  
    --atrous_rates=6  
    --atrous_rates=12  
    --atrous_rates=18   
    --output_stride=16  
    --decoder_output_stride=4  
    --num_classes=3 
    --crop_size=1080 
    --crop_size=1920 
    --inference_scales=1.0

  注意几点:

  –checkpoint_path 为自己模型保存的路径

  –export_path 模型导出保存的路径

  –num_classes=3 自己数据的类别数包含背景 

       –crop_size=1080  第一个为模型要求输入的高h

       –crop_size=1920 第一个为模型要求输入的宽w

三、模型测试

  直接上代码

  

# !--*-- coding:utf-8 --*--

# Deeplab Demo

import os
import tarfile

from matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import tempfile
from six.moves import urllib

import tensorflow as tf


class DeepLabModel(object):
    """
  加载 DeepLab 模型;
  推断 Inference
  """
    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    INPUT_SIZE = 1920
    FROZEN_GRAPH_NAME = 'frozen_inference_graph'

    def __init__(self, tarball_path):
        """
    Creates and loads pretrained deeplab model.
    """
        self.graph = tf.Graph()

        graph_def = None
        graph_def = tf.GraphDef.FromString(open(tarball_path, 'rb').read())

        if graph_def is None:
            raise RuntimeError('Cannot find inference graph in tar archive.')

        with self.graph.as_default():
            tf.import_graph_def(graph_def, name='')
        self.sess = tf.Session(graph=self.graph)

    def run(self, image):
        """
    Runs inference on a single image.
    Args:
    image: A PIL.Image object, raw input image.
    Returns:
    resized_image: RGB image resized from original input image.
    seg_map: Segmentation map of `resized_image`.
    """
        width, height = image.size
        resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
        target_size = (int(resize_ratio * width), int(resize_ratio * height))
        target_size = (1920,1080)
        resized_image = image.convert('RGB').resize(target_size, Image.ANTIALIAS)
        print(resized_image)
        batch_seg_map = self.sess.run(self.OUTPUT_TENSOR_NAME,
                                      feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
        seg_map = batch_seg_map[0]
        return resized_image, seg_map


def create_pascal_label_colormap():
    """
  Creates a label colormap used in PASCAL VOC segmentation benchmark.
  Returns:
      A Colormap for visualizing segmentation results.
  """
    colormap = np.zeros((256, 3), dtype=int)
    ind = np.arange(256, dtype=int)

    for shift in reversed(range(8)):
        for channel in range(3):
            colormap[:, channel] |= ((ind >> channel) & 1) << shift
        ind >>= 3

    return colormap


def label_to_color_image(label):
    """
  Adds color defined by the dataset colormap to the label.
  Args:
      label: A 2D array with integer type, storing the segmentation label.
  Returns:
      result: A 2D array with floating type. The element of the array
      is the color indexed by the corresponding element in the input label
      to the PASCAL color map.
  Raises:
      ValueError: If label is not of rank 2 or its value is larger than color
      map maximum entry.
  """
    if label.ndim != 2:
        raise ValueError('Expect 2-D input label')

    colormap = create_pascal_label_colormap()

    if np.max(label) >= len(colormap):
        raise ValueError('label value too large.')

    return colormap[label]


def vis_segmentation(image, seg_map):
    """Visualizes input image, segmentation map and overlay view."""
    plt.figure(figsize=(15, 5))
    grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])

    plt.subplot(grid_spec[0])
    plt.imshow(image)
    plt.axis('off')
    plt.title('input image')

    plt.subplot(grid_spec[1])
    seg_image = label_to_color_image(seg_map).astype(np.uint8)
    plt.imshow(seg_image)
    plt.axis('off')
    plt.title('segmentation map')

    plt.subplot(grid_spec[2])
    plt.imshow(image)
    plt.imshow(seg_image, alpha=0.7)
    plt.axis('off')
    plt.title('segmentation overlay')

    unique_labels = np.unique(seg_map)
    ax = plt.subplot(grid_spec[3])
    plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([], [])
    ax.tick_params(width=0.0)
    plt.grid('off')
    plt.show()

LABEL_NAMES = np.asarray(
    ['background', 'blackboard','screen'])
# LABEL_NAMES = np.asarray(
#     ['background', 'blackboard','screen'])

FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)



download_path =  r"D:python_projectdeeplabv3+lackboard_v2.pb"

MODEL = DeepLabModel(download_path)
print('model loaded successfully!')


##
def run_visualization(imagefile):
    """
  DeepLab 语义分割,并可视化结果.
  """
    orignal_im = Image.open(imagefile)
    print('running deeplab on image %s...' % imagefile)
    resized_im, seg_map = MODEL.run(orignal_im)
    print(seg_map.shape)

    vis_segmentation(resized_im, seg_map)


images_dir = r'D:python_projectdeeplabv3+	est_img'  # 测试图片目录所在位置
images = sorted(os.listdir(images_dir))
for imgfile in images:
    run_visualization(os.path.join(images_dir, imgfile))

print('Done.')

  需要注意的两点:

  1.images_dir 修改为自己存图片的dir

  2.INPUT_SIZE = 1920修改自己图片的hw最大的一个

  测试结果展示

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注