CRNN模型Python实现笔记三
阿里云国内75折 回扣 微信号:monov8 |
阿里云国际,腾讯云国际,低至75折。AWS 93折 免费开户实名账号 代冲值 优惠多多 微信号:monov8 飞机:@monov6 |
文章目录
一、函数讲解
1. numel()
函数
在Pytorch中, numel函数是torch.Tensor类的一个方法,它可以返回张量中的元素总数。
例如
import torch
x = torch.randn(2, 3)
print(x.numel())
这将输出6因为x有2行3列总共有6个元素。
值得注意的是,torch.Tensor.nelement()也和torch.Tensor.numel()
做同样的事情,所以您可以使用任何一个。
二、疑难代码段理解
1. strLabelConverter
类
# copy from utils
class strLabelConverter(object):
def __init__(self, alphabet, ignore_case=False):
self._ignore_case = ignore_case
if self._ignore_case:
alphabet = alphabet.lower()
self.alphabet = alphabet + '_' # for `-1` index
self.dict = {}
for i, char in enumerate(alphabet):
# NOTE: 0 is reserved for 'blank' required by wrap_ctc
self.dict[char] = i + 1
# print(self.dict)
def encode(self, text):
length = []
result = []
for item in text:
item = item.decode('utf-8', 'strict')
length.append(len(item))
for char in item:
if char not in self.dict.keys():
index = 0
else:
index = self.dict[char]
result.append(index)
text = result
return (torch.IntTensor(text), torch.IntTensor(length))
def decode(self, t, length, raw=False):
if length.numel() == 1:
length = length[0]
assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(),
length)
if raw:
return ''.join([self.alphabet[i - 1] for i in t])
else:
char_list = []
for i in range(length):
if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
char_list.append(self.alphabet[t[i] - 1])
return ''.join(char_list)
else:
# batch mode
assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
t.numel(), length.sum())
texts = []
index = 0
for i in range(length.numel()):
l = length[i]
texts.append(
self.decode(
t[index:index + l], torch.IntTensor([l]), raw=raw))
index += l
return texts
这段代码是定义了一个 strLabelConverter
类它主要用于文本编码和解码。
类中定义了三个函数
-
init(self, alphabet, ignore_case=False)
初始化函数用于创建一个strLabelConverter
实例。alphabet 是字符集ignore_case=False 表示是否忽略大小写。 -
encode(self, text)
将文本编码为整数序列。 -
decode(self, t, length, raw=False)
将整数序列解码为文本。
这个类中使用了字典 self.dict
和字符集 self.alphabet
将文本编码为整数序列和将整数序列解码为文本。
其中, self 是类中所有函数的第一个参数它代表的是类的实例本身在类的函数中可以通过 self 引用类的其它成员。
(1) def encode(self, text):
函数的作用
这段代码是在strLabelConverter类中encode函数中,主要用于将文本编码为整数序列。
item = item.decode('utf-8', 'strict')
Python中字符串默认是以unicode编码的如果字符串是以其他编码格式存储的那么就需要使用decode()
函数进行解码。
decode()
函数接受两个参数第一个参数是解码的编码格式第二个参数是非法字节序列的处理方式'strict’表示如果遇到非法字节序列将会抛出一个UnicodeDecodeError异常。
第二个参数的值有’strict’,‘ignore’,‘replace’,‘backslashreplace’ 四种。 'strict'
表示遇到非法字节序列会抛出异常 'ignore'
表示忽略非法字节序列'replace'
表示替换非法字节序列'backslashreplace'
表示反斜杠替换。
首先定义了两个空的列表length
和result
。然后遍历text中的每一项对每一项进行decode('utf-8', 'strict')
操作将其解码为utf-8格式。接着将每一项的长度添加到length列表中。
接下来对每一项中的每一个字符进行遍历如果字符不在self.dict.keys()中将其索引赋值为0否则将其索引赋值为self.dict[char]。最后把索引添加到result列表中。
最后将result赋值给text并用torch.IntTensor()
将text
和length
转换为张量后返回。
总之这个函数的作用是将文本中的每个字符转换为相应的索引并返回索引和文本长度的张量。
(2) def decode(self, text):
函数的作用
raw=False
有什么用?
raw=False
表示在解码时只保留非重复字符。
在strLabelConverter
类中decode函数中如果raw
参数为False
,那么对于t中的每一个整数序列如果不为0并且不是重复的那么将其对应的字符添加到char_list
中最后将char_list
拼接成字符串并返回。
这样的做法的意图是去除重复的字符,在实际应用中可能用于去除预测结果中的重复字符如果raw=True
则会返回所有字符。
(3) 关于函数定义中self
的疑惑
python在类中定义函数为什么参数列表里面都有
self
?
在 Python 中类中定义函数的参数列表里面加入 self
是为了让函数能够访问类的其它成员也是一种约定俗成的用法。
self
是类中所有函数的第一个参数它代表的是类的实例本身在类的函数中可以通过 self 引用类的其它成员。
当一个类的实例调用一个函数时系统会自动传入这个实例作为第一个参数这样函数就能访问类的其它成员。
简而言之self 参数起到类和函数之间连接的作用让函数能够访问类。
三、附录crnn_recognizer.py
import torch.nn as nn
# import torchvision.models as models
import torch, os
from PIL import Image
import cv2
import torchvision.transforms as transforms
from torch.autograd import Variable
import numpy as np
import random
from crnn import CRNN
import config
# copy from mydataset
class resizeNormalize(object):
def __init__(self, size, interpolation=Image.LANCZOS, is_test=True):
self.size = size
self.interpolation = interpolation
self.toTensor = transforms.ToTensor()
self.is_test = is_test
def __call__(self, img):
w, h = self.size
w0 = img.size[0]
h0 = img.size[1]
if w <= (w0 / h0 * h):
img = img.resize(self.size, self.interpolation)
img = self.toTensor(img)
img.sub_(0.5).div_(0.5)
else:
w_real = int(w0 / h0 * h)
img = img.resize((w_real, h), self.interpolation)
img = self.toTensor(img)
img.sub_(0.5).div_(0.5)
tmp = torch.zeros([img.shape[0], h, w])
start = random.randint(0, w - w_real - 1)
if self.is_test:
start = 0
tmp[:, :, start:start + w_real] = img
img = tmp
return img
# copy from utils
class strLabelConverter(object):
def __init__(self, alphabet, ignore_case=False):
self._ignore_case = ignore_case
if self._ignore_case:
alphabet = alphabet.lower()
self.alphabet = alphabet + '_' # for `-1` index
self.dict = {}
for i, char in enumerate(alphabet):
# NOTE: 0 is reserved for 'blank' required by wrap_ctc
self.dict[char] = i + 1
# print(self.dict)
def encode(self, text):
length = []
result = []
for item in text:
item = item.decode('utf-8', 'strict')
length.append(len(item))
for char in item:
if char not in self.dict.keys():
index = 0
else:
index = self.dict[char]
result.append(index)
text = result
return (torch.IntTensor(text), torch.IntTensor(length))
def decode(self, t, length, raw=False):
if length.numel() == 1:
length = length[0]
assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(),
length)
if raw:
return ''.join([self.alphabet[i - 1] for i in t])
else:
char_list = []
for i in range(length):
if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
char_list.append(self.alphabet[t[i] - 1])
return ''.join(char_list)
else:
# batch mode
assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
t.numel(), length.sum())
texts = []
index = 0
for i in range(length.numel()):
l = length[i]
texts.append(
self.decode(
t[index:index + l], torch.IntTensor([l]), raw=raw))
index += l
return texts
# recognize api
class PytorchOcr():
def __init__(self, model_path):
alphabet_unicode = config.alphabet_v2
self.alphabet = ''.join([chr(uni) for uni in alphabet_unicode])
# print(len(self.alphabet))
self.nclass = len(self.alphabet) + 1
self.model = CRNN(config.imgH, 1, self.nclass, 256)
self.cuda = False
if torch.cuda.is_available():
self.cuda = True
self.model.cuda()
self.model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path).items()})
else:
# self.model = nn.DataParallel(self.model)
self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
self.model.eval()
self.converter = strLabelConverter(self.alphabet)
def recognize(self, img):
h,w = img.shape[:2]
if len(img.shape) == 3:
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
image = Image.fromarray(img)
transformer = resizeNormalize((int(w/h*32), 32))
image = transformer(image)
image = image.view(1, *image.size())
image = Variable(image)
if self.cuda:
image = image.cuda()
preds = self.model(image)
_, preds = preds.max(2)
preds = preds.transpose(1, 0).contiguous().view(-1)
preds_size = Variable(torch.IntTensor([preds.size(0)]))
txt = self.converter.decode(preds.data, preds_size.data, raw=False)
return txt
if __name__ == '__main__':
model_path = './crnn_models/CRNN-1008.pth'
recognizer = PytorchOcr(model_path)
img_name = 't1.jpg'
img = cv2.imread(img_name)
h, w = img.shape[:2]
res = recognizer.recognize(img)
print(res)