【Instance Segmentation | e2ec】Environment Setup, Model Training, and Inference



Project Source

  • e2ec repository: https://github.com/zhang-tao-whu/e2ec

Environment Setup

git clone https://github.com/zhang-tao-whu/e2ec.git

# Set up the python environment
pip install Cython==0.28.2
pip install -r requirements.txt

# Compile cuda extensions
ROOT=/path/to/e2ec
cd $ROOT/network/backbone/DCNv2-master
# check your CUDA version and adjust the path below to match
export CUDA_HOME="/usr/local/cuda-11.0"
bash ./make.sh
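
Before compiling, it can help to confirm that PyTorch sees the GPU and that its CUDA version matches CUDA_HOME; a minimal sanity check (not part of the repo):

# check that the installed PyTorch matches the CUDA toolkit selected above
import torch
print(torch.__version__, torch.version.cuda)  # torch.version.cuda should match CUDA_HOME
print(torch.cuda.is_available())              # must be True for the DCNv2 build to be usable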

Model Training

  • Training command
nohup python train_net.py coco_finetune --bs 12 --type finetune --checkpoint data/model/model_coco.pth
  • Modify the configuration file configs/coco_finetune.py
from .base import commen, data, model, train, test
import numpy as np

# dataset statistics and input resolution; note that std, scale, input_w and
# input_h must be set on the data object (as data.*), otherwise the values
# inherited from the base config are silently used
data.mean = np.array([0.44726229, 0.43802511, 0.27905645],
                     dtype=np.float32).reshape(1, 1, 3)
data.std = np.array([0.22784984, 0.21254292, 0.16168552],
                    dtype=np.float32).reshape(1, 1, 3)

data.scale = np.array([640, 480])
data.input_w, data.input_h = (640, 480)

model.heads['ct_hm'] = 1  # number of object classes (center-heatmap channels)

train.optimizer = {'name': 'sgd', 'lr': 1e-4, 'weight_decay': 1e-4,
                   'milestones': [150, ], 'gamma': 0.1}
train.batch_size = 12
train.epoch = 160
train.dataset = 'coco_train'

test.dataset = 'coco_val'

class config(object):
    commen = commen
    data = data
    model = model
    train = train
    test = test
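
The scripts resolve the positional config_file argument against the configs package (get_cfg in visualize_video.py below uses this importlib pattern, and train_net.py presumably does the same), so the edited values can be verified from the repo root; a quick, illustrative check:

# load the finetune config exactly as the inference script does
import importlib
cfg = importlib.import_module('configs.coco_finetune').config
print(cfg.train.batch_size)                # expected: 12
print(cfg.data.input_w, cfg.data.input_h)  # expected: 640 480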

Model Inference

Image Prediction

  • Image prediction command
python visualize.py coco_finetune data/coco/train2017 \
--checkpoint data/model/159.pth --with_nms True --output_dir data/output_train

Video Prediction

  • Video prediction command (visualize_video.py does not ship with the repo; it is created in Steps 1-3 below)
python visualize_video.py coco_finetune --video data/dark_1280_960.mp4 \
--checkpoint data/model/159.pth --with_nms True --output_path data/output_video.mp4

Step 1: Add a video Dataset

  • Create a new file e2ec/dataset/demo_dataset_video.py
'''
Author: zth
Date: 2022-08-19 11:43:43
LastEditTime: 2022-08-19 14:46:40
Description: Dataset that reads frames from a video file for demo inference.
'''
import os
import cv2
import numpy as np
from .train.utils import augment
import torch.utils.data as data


class DatasetVideo(data.Dataset):
    def __init__(self, data_root, cfg):
        super(DatasetVideo, self).__init__()
        self.data_root = data_root
        self.split = 'test'
        self.cfg = cfg
        self.video = cv2.VideoCapture(data_root)  # held open only to query the frame count in __len__

    def read_video_frame(self, index):
        # open a fresh capture per frame so that random access by index works
        # even with multiple dataloader workers; slower than sequential reads
        video = cv2.VideoCapture(self.data_root)
        video.set(cv2.CAP_PROP_POS_FRAMES, index)
        success, frame = video.read()
        video.release()
        return frame

    def __getitem__(self, index):
        img = self.read_video_frame(index)

        orig_img, inp, trans_input, trans_output, flipped, center, scale, inp_out_hw = \
            augment(
                img, self.split,
                self.cfg.data.data_rng, self.cfg.data.eig_val, self.cfg.data.eig_vec,
                self.cfg.data.mean, self.cfg.data.std, self.cfg.commen.down_ratio,
                self.cfg.data.input_h, self.cfg.data.input_w, self.cfg.data.scale_range,
                self.cfg.data.scale, self.cfg.test.test_rescale, self.cfg.data.test_scale
            )

        ret = {'inp': inp}

        meta = {
            'center': center,
            'scale': scale,
            'test': '',
            'img_name': 'video_' + str(index) + ".jpg"
        }
        ret.update({'meta': meta})

        return ret

    def __len__(self):
        return int(self.video.get(cv2.CAP_PROP_FRAME_COUNT))
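
The random-access pattern above (seek with CAP_PROP_POS_FRAMES, then read) can be verified on its own before wiring it into the pipeline; a small illustrative check, using the demo video path from the prediction command:

# standalone check: read frame 100 of the input video by index
import cv2
cap = cv2.VideoCapture("data/dark_1280_960.mp4")
cap.set(cv2.CAP_PROP_POS_FRAMES, 100)
ok, frame = cap.read()
print(ok, None if frame is None else frame.shape)  # e.g. True (960, 1280, 3)
cap.release()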

Step 2: Add the video data-loading function

  • Add the following to e2ec/dataset/data_loader.py
def make_demo_video_loader(data_root=None, cfg=None):
    from .demo_dataset_video import DatasetVideo
    batch_size = 1      # one frame per batch
    shuffle = False     # keep the original frame order
    drop_last = False
    dataset = DatasetVideo(data_root, cfg)
    sampler = make_data_sampler(dataset, shuffle)
    batch_sampler = make_batch_data_sampler(sampler, batch_size, drop_last)
    num_workers = 1
    collator = collate_batch
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=collator)
    return data_loader
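
With the config loaded as above, the new loader can be smoke-tested directly; a hypothetical snippet (run from the repo root, video path taken from the prediction command):

# iterate one batch from the video loader
import importlib
from dataset.data_loader import make_demo_video_loader
cfg = importlib.import_module('configs.coco_finetune').config
loader = make_demo_video_loader('data/dark_1280_960.mp4', cfg=cfg)
batch = next(iter(loader))
print(batch['inp'].shape)  # one preprocessed frame (batch_size is fixed to 1)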

Step 3: Create the main script

  • Create e2ec/visualize_video.py
from network import make_network
import tqdm
import torch
import os
import nms
import post_process
from dataset.data_loader import make_demo_video_loader
from train.model_utils.utils import load_network
import argparse
import importlib
import matplotlib.pyplot as plt
import numpy as np
from itertools import cycle

import PIL.Image  # PIL.Image must be imported explicitly; "import PIL" alone does not expose it
import cv2
from io import BytesIO
import warnings
warnings.simplefilter("always")

parser = argparse.ArgumentParser()

parser.add_argument("config_file", help='/path/to/config_file.py')
parser.add_argument("--video", help='/path/to/video')
parser.add_argument(
    "--checkpoint", default='', help='/path/to/model_weight.pth')
parser.add_argument(
    "--ct_score",
    default=0.3,
    help='threshold to filter instances',
    type=float)
parser.add_argument(
    "--with_nms",
    default=False,
    # argparse's type=bool would treat any non-empty string (even 'False') as
    # True, so parse the flag value explicitly
    type=lambda s: str(s).lower() in ('true', '1'),
    help='if True, will use nms post-process operation')
parser.add_argument(
    "--with_post_process",
    default=False,
    type=lambda s: str(s).lower() in ('true', '1'),
    help='if True, will filter out some jaggies')
parser.add_argument(
    "--stage",
    default='final-dml',
    help='which stage of the contour will be generated',
    choices=['init', 'coarse', 'final', 'final-dml'])
parser.add_argument("--output_path", default='None', help='/path/to/output_dir/output.mp4')
parser.add_argument("--device", default=0, type=int, help='device idx')

args = parser.parse_args()


def get_cfg(args):
    cfg = importlib.import_module('configs.' + args.config_file).config
    cfg.test.with_nms = bool(args.with_nms)
    cfg.test.test_stage = args.stage
    cfg.test.ct_score = args.ct_score
    return cfg


def bgr_to_rgb(img):
    return img[:, :, [2, 1, 0]]


def unnormalize_img(img, mean, std):
    """
    img: [3, h, w]
    """
    img = img.detach().cpu().clone()
    img *= torch.tensor(std).view(3, 1, 1)
    img += torch.tensor(mean).view(3, 1, 1)
    min_v = torch.min(img)
    img = (img - min_v) / (torch.max(img) - min_v)
    return img

class VideoWriter:
    def __init__(self, name, width, height, fps=25):
        # type: (str, int, int, int) -> None
        if not name.endswith('.mp4'):  # make sure the file name ends with .mp4
            name += '.mp4'
            warnings.warn('video name should end with ".mp4"')
        self.__name = name      # output file name
        self.__height = height  # frame height
        self.__width = width    # frame width
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # 'mp4v' is the codec to use for .mp4 output
        self.__writer = cv2.VideoWriter(name, fourcc, fps, (width, height))

    def write(self, frame):
        if frame.dtype != np.uint8:  # frames must be uint8
            raise ValueError('frame.dtype should be np.uint8')
        # skip frames whose size does not match the writer's settings
        row, col, _ = frame.shape
        if row != self.__height or col != self.__width:
            warnings.warn('frame size does not match the size the writer was '
                          'created with; this frame will not be written')
            return
        self.__writer.write(frame)

    def close(self):
        self.__writer.release()

class VisualizerVideo(object):
    def __init__(self, cfg):
        self.cfg = cfg

    def visualize_ex(self, output, batch):
        inp = bgr_to_rgb(
            unnormalize_img(batch['inp'][0], self.cfg.data.mean,
                            self.cfg.data.std).permute(1, 2, 0))
        ex = output['py']
        ex = ex[-1] if isinstance(ex, list) else ex
        ex = ex.detach().cpu().numpy()

        fig, ax = plt.subplots(1, figsize=(20, 10))
        fig.tight_layout()
        ax.axis('off')
        ax.imshow(inp)

        colors = np.array([[31, 119, 180], [255, 127, 14], [46, 160, 44],
                           [214, 40, 39], [148, 103, 189], [140, 86, 75],
                           [227, 119, 194], [126, 126, 126], [188, 189, 32],
                           [26, 190, 207]]) / 255.
        np.random.shuffle(colors)
        colors = cycle(colors)
        for i in range(len(ex)):
            color = next(colors).tolist()
            poly = ex[i]
            poly = np.append(poly, [poly[0]], axis=0)
            ax.plot(poly[:, 0], poly[:, 1], color=color, lw=2)
        # render the figure into an in-memory buffer instead of a file on disk
        # (savefig captures the current pyplot figure)
        buffer_ = BytesIO()
        plt.savefig(buffer_, format='png', bbox_inches='tight')
        plt.close(fig)  # free the figure, otherwise figures accumulate across frames
        buffer_.seek(0)
        # read the PNG back with PIL and convert it to a numpy array
        dataPIL = PIL.Image.open(buffer_)
        data = np.asarray(dataPIL)
        buffer_.close()
        return data


def run_visualize(cfg):
    network = make_network.get_network(cfg).cuda()
    load_network(network, args.checkpoint)
    network.eval()

    data_loader = make_demo_video_loader(args.video, cfg=cfg)
    visualizer = VisualizerVideo(cfg)

    # query the input video's properties
    video = cv2.VideoCapture(args.video)
    width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frames_per_second = video.get(cv2.CAP_PROP_FPS)
    num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    video.release()

    # run inference and write the results to a video file.
    # NOTE: the frames produced by visualize_ex are rendered by matplotlib with
    # bbox_inches='tight', so their size differs from the source video; the
    # hard-coded width of 1264 matches the rendered frames for this particular
    # setup and may need adjusting for other inputs.
    vw = VideoWriter(args.output_path, 1264, height, frames_per_second)
    os.makedirs("data/output_video", exist_ok=True)  # directory for the per-frame JPEG dumps

    i = 0
    for batch in tqdm.tqdm(data_loader):
        for k in batch:
            if k != 'meta':
                batch[k] = batch[k].cuda()
        with torch.no_grad():
            output = network(batch['inp'], batch)
        if args.with_post_process:
            post_process.post_process(output)
        if args.with_nms:
            nms.post_process(output)
        output_frame = visualizer.visualize_ex(output, batch)
        # PIL gives RGBA; convert to the BGR channel order OpenCV expects
        output_frame = cv2.cvtColor(output_frame, cv2.COLOR_RGBA2BGR)
        # also dump each frame as a JPEG for inspection
        cv2.imwrite("data/output_video/" + str(i) + ".jpg", output_frame)
        # append the frame to the output video
        vw.write(output_frame)
        i = i + 1
        if i > 20:
            break  # debug limit: stop after the first frames; remove to process the whole video
    # release the writer
    vw.close()


if __name__ == "__main__":
    cfg = get_cfg(args)
    torch.cuda.set_device(args.device)
    run_visualize(cfg)
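
The cv2.VideoWriter pattern wrapped by the VideoWriter class above can also be exercised in isolation; a self-contained sketch with illustrative path and frame size:

# write 50 synthetic uint8 frames to an mp4 (path/size are examples only)
import cv2
import numpy as np

writer = cv2.VideoWriter('data/demo_out.mp4',
                         cv2.VideoWriter_fourcc(*'mp4v'), 25, (320, 240))
for t in range(50):
    frame = np.full((240, 320, 3), t * 5 % 256, dtype=np.uint8)  # must be uint8, (h, w, 3)
    writer.write(frame)
writer.release()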

References

  1. Python OpenCV: reading a specific frame, or reading a whole video into a numpy array
  2. Python: converting a Matplotlib figure to a numpy array
  3. Writing video with Python OpenCV