GH/utils/aliyun_ocr.py

229 lines
8.9 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""
阿里云OCR服务集成
使用阿里云AI大模型进行图片文字识别
"""
import base64
import json
import os
from dotenv import load_dotenv
from alibabacloud_ocr_api20210707.client import Client as ocr_api20210707Client
from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_ocr_api20210707 import models as ocr_api20210707_models
from alibabacloud_tea_util import models as util_models
from alibabacloud_tea_util.client import Client as UtilClient
# Load environment variables (python-dotenv) at import time so that the
# os.getenv('ALIYUN_...') lookups below can see values defined in a .env file.
load_dotenv()
class AliyunOCR:
    """Client wrapper for Alibaba Cloud OCR (``alibabacloud_ocr_api20210707`` SDK).

    Credentials and endpoint come from the constructor arguments or, when
    omitted, from the ``ALIYUN_ACCESS_KEY_ID``, ``ALIYUN_ACCESS_KEY_SECRET`` and
    ``ALIYUN_OCR_ENDPOINT`` environment variables.
    """

    def __init__(self, access_key_id=None, access_key_secret=None, endpoint=None):
        """Create the underlying OCR API client.

        Args:
            access_key_id: AccessKey ID; defaults to ``$ALIYUN_ACCESS_KEY_ID``.
            access_key_secret: AccessKey secret; defaults to
                ``$ALIYUN_ACCESS_KEY_SECRET``.
            endpoint: API host; defaults to ``$ALIYUN_OCR_ENDPOINT`` or
                ``ocr-api.cn-hangzhou.aliyuncs.com``.

        Raises:
            Exception: if the AccessKey ID or secret is missing or empty.
        """
        self.access_key_id = access_key_id or os.getenv('ALIYUN_ACCESS_KEY_ID')
        self.access_key_secret = access_key_secret or os.getenv('ALIYUN_ACCESS_KEY_SECRET')
        self.endpoint = endpoint or os.getenv('ALIYUN_OCR_ENDPOINT', 'ocr-api.cn-hangzhou.aliyuncs.com')
        if not self.access_key_id or not self.access_key_secret:
            raise Exception("阿里云AccessKey未配置请在.env文件中设置ALIYUN_ACCESS_KEY_ID和ALIYUN_ACCESS_KEY_SECRET")
        config = open_api_models.Config(
            access_key_id=self.access_key_id,
            access_key_secret=self.access_key_secret,
        )
        config.endpoint = self.endpoint
        self.client = ocr_api20210707Client(config)

    def recognize_general(self, image_path):
        """Run general text recognition (RecognizeGeneral) on a local image file.

        Returns:
            str: text flattened by ``_extract_text``.

        Raises:
            Exception: message prefixed ``阿里云OCR识别错误: `` on any failure
                (unreadable file, SDK/API error, empty response).
        """
        result = self._call_api(
            image_path,
            'OCR识别',
            [('RecognizeGeneralRequest', 'recognize_general')],
        )
        return self._extract_text(result)

    def recognize_advanced(self, image_path, options=None):
        """Run advanced text recognition (RecognizeAdvanced) on a local image file.

        Args:
            image_path: local image file path.
            options: optional mapping; only the keys ``output_char_info``,
                ``output_table`` and ``need_rotate`` are copied onto the
                request, other keys are ignored.

        Returns:
            str: text flattened by ``_extract_text``.

        Raises:
            Exception: message prefixed ``阿里云高级OCR识别错误: `` on any failure.
        """
        options = options or {}
        request_fields = {
            key: options[key]
            for key in ('output_char_info', 'output_table', 'need_rotate')
            if key in options
        }
        result = self._call_api(
            image_path,
            '高级OCR识别',
            [('RecognizeAdvancedRequest', 'recognize_advanced')],
            request_fields,
        )
        return self._extract_text(result)

    def recognize_table(self, image_path):
        """Run table recognition on a local image file.

        Returns:
            str: table text rendered by ``_extract_table_data``.

        Raises:
            Exception: message prefixed ``阿里云表格识别错误: `` on any failure.
        """
        # NOTE(review): ocr-api 2021-07-07 appears to expose table OCR as
        # RecognizeTableOcr; the RecognizeTable names used previously are kept
        # as a fallback in case the installed SDK provides those instead.
        result = self._call_api(
            image_path,
            '表格识别',
            [
                ('RecognizeTableOcrRequest', 'recognize_table_ocr'),
                ('RecognizeTableRequest', 'recognize_table'),
            ],
        )
        return self._extract_table_data(result)

    def _resolve_api(self, candidates):
        """Return ``(request class, bound client method)`` for the first candidate the SDK provides."""
        for request_name, method_name in candidates:
            request_cls = getattr(ocr_api20210707_models, request_name, None)
            api_method = getattr(self.client, method_name, None)
            if request_cls is not None and api_method is not None:
                return request_cls, api_method
        raise AttributeError(f"OCR SDK supports none of {candidates!r}")

    def _call_api(self, image_path, label, candidates, request_fields=None):
        """Upload one image to an OCR API and return the parsed ``data`` JSON.

        Args:
            image_path: local image file; its raw bytes are sent as the body.
            label: operation name used in error messages
                (``阿里云{label}失败`` / ``阿里云{label}错误``).
            candidates: ``(request class name, client method name)`` pairs;
                the first pair available in the installed SDK is used.
            request_fields: extra attributes to set on the request object.

        Returns:
            dict: ``json.loads`` of the response body's ``data`` string.

        Raises:
            Exception: ``阿里云{label}错误: ...`` wrapping any failure, chained
                to the original cause.
        """
        try:
            request_cls, api_method = self._resolve_api(candidates)
            with open(image_path, 'rb') as image_file:
                # The SDK streams ``body`` as the raw HTTP request body, so pass
                # the open binary file (not a base64 string). The call stays
                # inside the ``with`` so the stream is still open while sending.
                request = request_cls(body=image_file)
                for key, value in (request_fields or {}).items():
                    setattr(request, key, value)
                response = api_method(request)
            body = response.body
            # NOTE(review): in ocr-api 2021-07-07 ``Code`` is a string error code
            # and failures normally surface as SDK exceptions, so a non-empty
            # ``data`` is used as the success signal instead of ``code == 200``;
            # confirm against the installed SDK version.
            if body is None or not body.data:
                message = body.message if body is not None else None
                raise Exception(f"阿里云{label}失败: {message}")
            return json.loads(body.data)
        except Exception as e:
            raise Exception(f"阿里云{label}错误: {str(e)}") from e

    def _extract_text(self, result):
        """Flatten a parsed OCR ``data`` dict into plain text.

        The first matching key wins:
          * ``content``: returned stripped;
          * ``prism_wordsInfo``: each entry's ``word``, one per line;
          * ``prism_tablesInfo``: each table's ``cellContents`` words joined by
            tabs, one table per line.
        Returns an empty string when none of these keys is present.
        """
        if 'content' in result:
            return (result['content'] or '').strip()
        if 'prism_wordsInfo' in result:
            words = [
                word_info['word']
                for word_info in (result['prism_wordsInfo'] or [])
                if 'word' in word_info
            ]
            return '\n'.join(words).strip()
        if 'prism_tablesInfo' in result:
            # NOTE(review): assumes table cells live under ``cellContents``;
            # verify the key against real table-API responses.
            lines = [
                '\t'.join(
                    cell['word']
                    for cell in (table_info.get('cellContents') or [])
                    if 'word' in cell
                )
                for table_info in (result['prism_tablesInfo'] or [])
            ]
            return '\n'.join(lines).strip()
        return ''

    def _extract_table_data(self, result):
        """Render table OCR results as tab-separated text.

        Returns ``result['content']`` unchanged when present. Otherwise every
        table in ``prism_tablesInfo`` is laid out as a grid built from its
        ``cellContents`` entries (``row``/``col`` indices default to 0,
        ``word`` text), with rows joined by newlines and tables separated by a
        blank line. Returns ``"未识别到表格数据"`` if no table has any cells.
        """
        if 'content' in result:
            return result['content']
        rendered_tables = []
        # NOTE(review): the ``cellContents``/``row``/``col`` keys are assumed;
        # confirm them against real table-API responses.
        for table_info in result.get('prism_tablesInfo') or []:
            cells = table_info.get('cellContents') or []
            if not cells:
                # max() below fails on an empty sequence; a table with no cells
                # contributes nothing.
                continue
            n_rows = max(cell.get('row', 0) for cell in cells) + 1
            n_cols = max(cell.get('col', 0) for cell in cells) + 1
            grid = [[''] * n_cols for _ in range(n_rows)]
            for cell in cells:
                row, col = cell.get('row', 0), cell.get('col', 0)
                # Reject negative indices, which would otherwise wrap around.
                if 0 <= row < n_rows and 0 <= col < n_cols:
                    grid[row][col] = cell.get('word') or ''
            rendered_tables.append('\n'.join('\t'.join(line) for line in grid))
        return '\n\n'.join(rendered_tables) if rendered_tables else "未识别到表格数据"
def extract_text_with_aliyun(image_path, ocr_type='general', options=None):
    """Recognize text in a local image file with Alibaba Cloud OCR.

    Args:
        image_path: path to the image file.
        ocr_type: ``'general'`` (default), ``'advanced'`` or ``'table'``.
        options: forwarded to ``AliyunOCR.recognize_advanced``; ignored for the
            other OCR types.

    Returns:
        str: the recognized text (for ``'table'``, the rendered table text).

    Raises:
        Exception: message prefixed ``阿里云OCR识别失败: `` for an unsupported
            ``ocr_type``, missing credentials, or any recognition error.
    """
    # Validate before building a client, so a bad type neither requires
    # credentials nor gets re-wrapped by the handler below.
    if ocr_type not in ('general', 'advanced', 'table'):
        raise Exception(f"阿里云OCR识别失败: 不支持的OCR类型: {ocr_type}")
    try:
        ocr_client = AliyunOCR()
        if ocr_type == 'general':
            return ocr_client.recognize_general(image_path)
        if ocr_type == 'advanced':
            return ocr_client.recognize_advanced(image_path, options)
        return ocr_client.recognize_table(image_path)
    except Exception as e:
        raise Exception(f"阿里云OCR识别失败: {str(e)}") from e
def check_aliyun_config():
    """Report whether Aliyun OCR credentials are configured.

    This is a local check only: it reads the credential environment variables
    and constructs an ``AliyunOCR`` client, but calls no OCR API, so keys that
    are present but invalid are not detected here.

    Returns:
        tuple[bool, str]: ``(ok, status message)``.
    """
    access_key_id = os.getenv('ALIYUN_ACCESS_KEY_ID')
    access_key_secret = os.getenv('ALIYUN_ACCESS_KEY_SECRET')
    if not access_key_id or not access_key_secret:
        return False, "阿里云AccessKey未配置"
    try:
        # Constructing the client validates the configuration; the instance
        # itself is not needed.
        AliyunOCR()
    except Exception as e:
        return False, f"阿里云OCR配置错误: {str(e)}"
    return True, "阿里云OCR配置正确"