您的位置:首页 > 科技 > IT业 > 深度学习系列71:表格检测和识别

深度学习系列71:表格检测和识别

2024/12/23 12:21:12 来源:https://blog.csdn.net/kittyzc/article/details/141381625  浏览:    关键词:深度学习系列71:表格检测和识别

1. pdf转图片

使用pdf2image库:

from pdf2image import convert_from_path
img = np.array(convert_from_path(path, dpi=800, use_cropbox=True)[0])

2. 表格位置检测

2.1 使用ppstructure

使用paddleocr库中的ppstructure可以方便获取表格位置,参考代码:

from paddleocr import PPStructure
structure = table_engine(source_img)

2.2 使用tabledetector

import tabledetector as td
result = td.detect(pdf_path="pdf_path", type="bordered", rotation=False, method='detect')

2.3 使用cv2的图形学方法

调试简单,具体代码如下:

  1. 二值化去除水印
    在这里插入图片描述
  2. 使用getStructuringElement获取纵线和横线
    在这里插入图片描述
  3. 两者合并,使用findContours获取表格外边框和内部单元格 在这里插入图片描述

3. 位置确认

获取所有单元格后,使用下面的函数获取单元格的相对位置关系:

from typing import Dict, List, Tuple
import numpy as npclass TableRecover:def __init__(self,):passdef __call__(self, polygons: np.ndarray) -> Dict[int, Dict]:rows = self.get_rows(polygons)longest_col, each_col_widths, col_nums = self.get_benchmark_cols(rows, polygons)each_row_heights, row_nums = self.get_benchmark_rows(rows, polygons)table_res = self.get_merge_cells(polygons,rows,row_nums,col_nums,longest_col,each_col_widths,each_row_heights,)return table_res@staticmethoddef get_rows(polygons: np.array) -> Dict[int, List[int]]:"""对每个框进行行分类,框定哪个是一行的"""y_axis = polygons[:, 0, 1]if y_axis.size == 1:return {0: [0]}concat_y = np.array(list(zip(y_axis, y_axis[1:])))minus_res = concat_y[:, 1] - concat_y[:, 0]result = {}thresh = 5.0split_idxs = np.argwhere(minus_res > thresh).squeeze()if split_idxs.ndim == 0:split_idxs = split_idxs[None, ...]if max(split_idxs) != len(minus_res):split_idxs = np.append(split_idxs, len(minus_res))start_idx = 0for row_num, idx in enumerate(split_idxs):if row_num != 0:start_idx = split_idxs[row_num - 1] + 1result.setdefault(row_num, []).extend(range(start_idx, idx + 1))# 计算每一行相邻cell的iou,如果大于0.2,则合并为同一个cellreturn resultdef get_benchmark_cols(self, rows: Dict[int, List], polygons: np.ndarray) -> Tuple[np.ndarray, List[float], int]:longest_col = max(rows.values(), key=lambda x: len(x))longest_col_points = polygons[longest_col]longest_x = longest_col_points[:, 0, 0]theta = 10for row_value in rows.values():cur_row = polygons[row_value][:, 0, 0]range_res = {}for idx, cur_v in enumerate(cur_row):start_idx, end_idx = None, Nonefor i, v in enumerate(longest_x):if cur_v - theta <= v <= cur_v + theta:breakif cur_v > v:start_idx = icontinueif cur_v < v:end_idx = ibreakrange_res[idx] = [start_idx, end_idx]sorted_res = dict(sorted(range_res.items(), key=lambda x: x[0], reverse=True))for k, v in sorted_res.items():if v[0]==None or v[1]==None:continuelongest_x = np.insert(longest_x, v[1], cur_row[k])longest_col_points = np.insert(longest_col_points, v[1], polygons[row_value[k]], axis=0)# 求出最右侧所有cell的宽,其中最小的作为最后一列宽度rightmost_idxs = [v[-1] for v in rows.values()]rightmost_boxes = polygons[rightmost_idxs]min_width = min([self.compute_L2(v[3, :], v[0, :]) for v in rightmost_boxes])each_col_widths = (longest_x[1:] - longest_x[:-1]).tolist()each_col_widths.append(min_width)col_nums = longest_x.shape[0]return longest_col_points, each_col_widths, col_numsdef get_benchmark_rows(self, rows: Dict[int, List], polygons: np.ndarray) -> Tuple[np.ndarray, List[float], int]:leftmost_cell_idxs = [v[0] for v in rows.values()]benchmark_x = polygons[leftmost_cell_idxs][:, 0, 1]theta = 10# 遍历其他所有的框,按照y轴进行区间划分range_res = {}for cur_idx, cur_box in enumerate(polygons):if cur_idx in benchmark_x:continuecur_y = cur_box[0, 1]start_idx, end_idx = None, Nonefor i, v in enumerate(benchmark_x):if cur_y - theta <= v <= cur_y + theta:breakif cur_y > v:start_idx = icontinueif cur_y < v:end_idx = ibreakrange_res[cur_idx] = [start_idx, end_idx]sorted_res = dict(sorted(range_res.items(), key=lambda x: x[0], reverse=True))for k, v in sorted_res.items():if v[0]==None or v[1]==None:continuebenchmark_x = np.insert(benchmark_x, v[1], polygons[k][0, 1])each_row_widths = (benchmark_x[1:] - benchmark_x[:-1]).tolist()# 求出最后一行cell中,最大的高度作为最后一行的高度bottommost_idxs = list(rows.values())[-1]bottommost_boxes = polygons[bottommost_idxs]max_height = max([self.compute_L2(v[3, :], v[0, :]) for v in bottommost_boxes])each_row_widths.append(max_height)row_nums = benchmark_x.shape[0]return each_row_widths, row_nums@staticmethoddef compute_L2(a1: np.ndarray, a2: np.ndarray) -> float:return np.linalg.norm(a2 - a1)def get_merge_cells(self,polygons: np.ndarray,rows: Dict,row_nums: int,col_nums: int,longest_col: np.ndarray,each_col_widths: List[float],each_row_heights: List[float],) -> Dict[int, Dict[int, int]]:col_res_merge, row_res_merge = {}, {}merge_thresh = 20for cur_row, col_list in rows.items():one_col_result, one_row_result = {}, {}for one_col in col_list:box = polygons[one_col]box_width = self.compute_L2(box[3, :], box[0, :])# 不一定是从0开始的,应该综合已有值和x坐标位置来确定起始位置loc_col_idx = np.argmin(np.abs(longest_col[:, 0, 0] - box[0, 0]))merge_col_cell = max(sum(one_col_result.values()), loc_col_idx)# 计算合并多少个列方向单元格for i in range(merge_col_cell, col_nums):col_cum_sum = sum(each_col_widths[merge_col_cell : i + 1])if i == merge_col_cell and col_cum_sum > box_width:one_col_result[one_col] = 1breakelif abs(col_cum_sum - box_width) <= merge_thresh:one_col_result[one_col] = i + 1 - merge_col_cellbreakelse:one_col_result[one_col] = i + 1 - merge_col_cell + 1box_height = self.compute_L2(box[1, :], box[0, :])merge_row_cell = cur_rowfor j in range(merge_row_cell, row_nums):row_cum_sum = sum(each_row_heights[merge_row_cell : j + 1])# box_height 不确定是几行的高度,所以要逐个试验,找一个最近的几行的高# 如果第一次row_cum_sum就比box_height大,那么意味着?丢失了一行if j == merge_row_cell and row_cum_sum > box_height:one_row_result[one_col] = 1breakelif abs(box_height - row_cum_sum) <= merge_thresh:one_row_result[one_col] = j + 1 - merge_row_cellbreakelse:one_row_result[one_col] = j + 1 - merge_row_cell + 1col_res_merge[cur_row] = one_col_resultrow_res_merge[cur_row] = one_row_resultres = {}for i, (c, r) in enumerate(zip(col_res_merge.values(), row_res_merge.values())):res[i] = {k: [cc, r[k]] for k, cc in c.items()}return res

调用代码如下:

h_min = 10
h_max = 5000def sortContours(cnts, method='left-to-right'):reverse = Falsei = 0if method == "right-to-left" or method == "bottom-to-top":reverse = Trueif method == "top-to-bottom" or method == "bottom-to-top":i = 1boundingBoxes = [cv2.boundingRect(c) for c in cnts](cnts, boundingBoxes) = zip(*sorted(zip(cnts, boundingBoxes),key=lambda b: b[1][i], reverse=reverse))return (cnts, boundingBoxes)def sorted_boxes(dt_boxes):num_boxes = dt_boxes.shape[0]dt_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))_boxes = list(dt_boxes)for i in range(num_boxes - 1):for j in range(i, -1, -1):if (abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10and _boxes[j + 1][0][0] < _boxes[j][0][0]):_boxes[j], _boxes[j + 1] = _boxes[j + 1], _boxes[j]else:breakreturn _boxesdef getBboxDtls(raw):######### 1. 获得表格的边框,确保merge正确展示了图中的表格边框gray = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY)binary = 255-cv2.threshold(gray, 200, 255, cv2.THRESH_BINARY)[1]rows, cols = binary.shapescale = 30kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (cols // scale, 1))eroded = cv2.erode(binary, kernel, iterations=1)dilated_col = cv2.dilate(eroded, kernel, iterations=1)scale = 20kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, rows // scale))eroded = cv2.erode(binary, kernel, iterations=1)dilated_row = cv2.dilate(eroded, kernel, iterations=1)merge = cv2.add(dilated_col, dilated_row)kernel = np.ones((3,3),np.uint8)merge = cv2.erode(cv2.dilate(merge, kernel, iterations=3), kernel, iterations=3)plt.figure(figsize=(60,30))io.imshow(merge[1500:2500])########## 2. 获取表格坐标tableData = []contours = cv2.findContours(merge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)contours = contours[0] if len(contours) == 2 else contours[1]contours, boundingBoxes = sortContours(contours, method='top-to-bottom')# 获取表格外边框for c in contours:x, y, w, h = cv2.boundingRect(c)if (h>h_min):tableData.append((x, y, w, h))        # 获取表格内部的单元格contours, hierarchy = cv2.findContours(merge, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)contours, boundingBoxes = sortContours(contours, method="top-to-bottom")boxes = []for c in contours:x, y, w, h = cv2.boundingRect(c)if (h>h_min) and (h<h_max):boxes.append([x, y, w, h])########## 3. 计算表格单元格位置关系bboxDtls = {}for tableBox1 in tableData:key = tableBox1values = []for tableBox2 in boxes:x2, y2, w2, h2 = tableBox2if tableBox1[0] <= x2 <= tableBox1[0] + tableBox1[2] and tableBox1[1] <= y2 <= tableBox1[1] + tableBox1[3]:values.append(tableBox2)bboxDtls[key] = valuesfor key, values in bboxDtls.items():x_tab, y_tab, w_tab, h_tab = keyfor box in values:x_box, y_box, w_box, h_box = boxreturn bboxDtls

4. 表格文字识别

4.1 wired_table_rec

可以尝试使用wired_table_rec进行识别:

from wired_table_rec import WiredTableRecognition
table_rec = WiredTableRecognition()
table_str = table_rec(cv2.imread(img_path))[0]
HTML(table_str)

4.2 rapidocr_onnxruntime或者pytessect

或者可以使用更原子化的ocr服务,逐个单元格进行ocr识别,完整代码如下:

"""
首先安装pdf2image和rapidocr_onnxruntime两个库。
图像处理部分的参数和代码可以自行调整:
1. pad参数用于去除图片的边框
2. 转pdf时,有时候800dpi会失败,因此需要加入try except
3. 图片太小时ocr效果不好,因此做了resize。这里的3000可以自行调整。
4. 大图片做了二值化处理,目的是去除水印的干扰。这里的180也可以尝试自行调整。
5. 只处理第一个单元格总数大于50的表格。如果要识别图片中所有表格,可修改代码。
6. 返回的是html格式的表格,可以用pd.read_html函数转为dataframe
"""
rocr = RapidOCR()
rocr.text_det.preprocess_op = DetPreProcess(736, 'max')
def getResult(path,pad = 20, resize_thresh=3000, binary_thresh=180):if 'pdf' in path:try:source_img = np.array(convert_from_path(path, dpi=800, use_cropbox=True)[0])[pad:-pad,pad:-pad]except:source_img = np.array(convert_from_path(path, dpi=300, use_cropbox=True)[0])[pad:-pad,pad:-pad]else:source_img = cv2.imread(path)[pad:-pad,pad:-pad]if source_img.shape[1] < resize_thresh:source_img =cv2.resize(source_img,(resize_thresh,int(source_img.shape[0]/source_img.shape[1]*resize_thresh)))img = cv2.threshold(cv2.cvtColor(source_img, cv2.COLOR_BGR2GRAY), binary_thresh, 255, cv2.THRESH_BINARY)[1] bboxDtls = getBboxDtls(source_img)boxes = []table = None# 寻找到第一个单元格数大于50的表后停止for k,v in bboxDtls.items():if len(v)>50:table = kfor r in tqdm(v[1:]):res = rocr(img[r[1]: r[1]+r[3],r[0]:r[2]+r[0]])[0]if res!=None:res.sort(key = lambda x:(x[0][0][1]//(img.shape[1]//20),x[0][0][0]//(img.shape[0]//20)))boxes.append([[r[0], r[1]], [r[0], r[1]+r[3]],[r[0]+r[2], r[1]+r[3]], [r[0]+r[2], r[1]], ''.join([t[1].replace('\n','').replace(' ','') for t in res])])else:boxes.append([[r[0], r[1]], [r[0], r[1]+r[3]],[r[0]+r[2], r[1]+r[3]], [r[0]+r[2], r[1]], ''])  breakpolygons = sorted_boxes(np.array(boxes))texts = [p[4] for p in polygons]tr = TableRecover()table_res = tr(np.array([[np.array(p[0]),np.array(p[1]),np.array(p[2]),np.array(p[3])] for p in polygons]))table_html = """<table border="1" cellspacing="0">"""for vs in table_res.values():table_html+="<tr>"for i,v in vs.items():table_html+=f"""<td colspan="{v[0]}" rowspan="{v[1]}">{texts[i]}</td>"""table_html+="</tr>"table_html+="""</table>"""return table_html

原图为:https://www.95598.cn/omg-static/99107281818076039603801539578309.jpg
最终识别出来的结果如下:
在这里插入图片描述

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com