RESA车道线路沿检测

阿里云国内75折 回扣 微信号:monov8
阿里云国际,腾讯云国际,低至75折。AWS 93折 免费开户实名账号 代冲值 优惠多多 微信号:monov8 飞机:@monov6

一、当前车道线检测遇到的问题

1、车道标注中固有的稀疏监督信号使其一直很有挑战性

2、传统卷积不能很有效的提取细长的车道线和路沿(方格内有效特征很少)没有利用形状先验

3、SCNN提出在行列间传递信息但是顺序信息传递是耗时的相邻行列传递信息需要多次迭代长距离传播容易丢失信息

二、本论文提出的方法

1、提出RESA模块利用车道的强形状先验捕获行列间的空间关系

a.   采用并行方式传递信息大大降低时间成本。

b.   信息以不同的步长传播一个像素多次叠加另一个像素防止长距离信息丢失同时使每个像素都能够收集全局信息。

c.    可以很方便的合并到其他网络

2、提出双边上采样解码器BUSD

一个分支用于捕获粗粒度特征upsample上采样算子采用双线性插值

一个分支用于捕获细粒度特征转置卷积+两个non-bottleneck修复细微损失

可以将低分辨率特征精确的恢复为逐像素特征

三、模型结构

 编码层ResNet或VGG

RESA层

BUSD双边上采样层

四、代码解析

class RESA(nn.Module):
    def __init__(self, cfg):
        super(RESA, self).__init__()
        self.iter = cfg.resa.iter  # 5
        chan = cfg.resa.input_channel  #128
        fea_stride = cfg.backbone.fea_stride  #8
        self.height = cfg.img_height // fea_stride
        self.width = cfg.img_width // fea_stride
        self.alpha = cfg.resa.alpha  # 2
        conv_stride = cfg.resa.conv_stride #9
        #每个方向的卷积都要迭代iter次初始化卷积
        for i in range(self.iter):
            #一行九列卷一下
            conv_vert1 = nn.Conv2d(
                chan, chan, (1, conv_stride),
                padding=(0, conv_stride//2), groups=1, bias=False)
            conv_vert2 = nn.Conv2d(
                chan, chan, (1, conv_stride),
                padding=(0, conv_stride//2), groups=1, bias=False)
            #setattr(object, name, value)用于设置属性值
            setattr(self, 'conv_d'+str(i), conv_vert1)
            setattr(self, 'conv_u'+str(i), conv_vert2)
            #九行一列卷一下
            conv_hori1 = nn.Conv2d(
                chan, chan, (conv_stride, 1),
                padding=(conv_stride//2, 0), groups=1, bias=False)
            conv_hori2 = nn.Conv2d(
                chan, chan, (conv_stride, 1),
                padding=(conv_stride//2, 0), groups=1, bias=False)

            setattr(self, 'conv_r'+str(i), conv_hori1)
            setattr(self, 'conv_l'+str(i), conv_hori2)
            #[1,2,3,4.......,31,0]
            #[2,3,4,5,......,0,1]
            #[4,5,6,7,......,2,3]
            #[8,9,10,.......,6,7]
            idx_d = (torch.arange(self.height) + self.height //
                     2**(self.iter - i)) % self.height
            setattr(self, 'idx_d'+str(i), idx_d)

            idx_u = (torch.arange(self.height) - self.height //
                     2**(self.iter - i)) % self.height
            setattr(self, 'idx_u'+str(i), idx_u)

            idx_r = (torch.arange(self.width) + self.width //
                     2**(self.iter - i)) % self.width
            setattr(self, 'idx_r'+str(i), idx_r)

            idx_l = (torch.arange(self.width) - self.width //
                     2**(self.iter - i)) % self.width
            setattr(self, 'idx_l'+str(i), idx_l)

    def forward(self, x):
        x = x.clone()

        for direction in ['d', 'u']:
            for i in range(self.iter):
                #获取对象属性
                conv = getattr(self, 'conv_' + direction + str(i))
                idx = getattr(self, 'idx_' + direction + str(i))
                #在一行九列上卷积在行上相加
                x.add_(self.alpha * F.relu(conv(x[..., idx, :])))

        for direction in ['r', 'l']:
            for i in range(self.iter):
                conv = getattr(self, 'conv_' + direction + str(i))
                idx = getattr(self, 'idx_' + direction + str(i))
                x.add_(self.alpha * F.relu(conv(x[..., idx])))

        return x

阿里云国内75折 回扣 微信号:monov8
阿里云国际,腾讯云国际,低至75折。AWS 93折 免费开户实名账号 代冲值 优惠多多 微信号:monov8 飞机:@monov6