# GOT-OCR-Optimize / got_ocr.py
# Last commit fffa248 by Mageia: "fix: got_ocr" (unverified). File size: 2.45 kB.
import base64
import os
# got_mode -> (model method name, ocr_type, whether ocr_box/ocr_color are forwarded)
_GOT_MODES = {
    "plain texts OCR": ("chat", "ocr", False),
    "format texts OCR": ("chat", "format", False),
    "plain multi-crop OCR": ("chat_crop", "ocr", False),
    "format multi-crop OCR": ("chat_crop", "format", False),
    "plain fine-grained OCR": ("chat", "ocr", True),
    "format fine-grained OCR": ("chat", "format", True),
}


def got_ocr(model, tokenizer, image_path, got_mode="format texts OCR", fine_grained_mode="", ocr_color="", ocr_box=""):
    """Run GOT-OCR on one image in the requested mode.

    Args:
        model: model object exposing ``chat`` and ``chat_crop`` methods
            (the GOT-OCR remote-code model).
        tokenizer: tokenizer passed straight through to the model call.
        image_path: path of the input image. For "format" modes the rendered
            result is written next to it as ``<path without ext>_result.html``.
        got_mode: one of "plain texts OCR", "format texts OCR",
            "plain multi-crop OCR", "format multi-crop OCR",
            "plain fine-grained OCR", "format fine-grained OCR".
        fine_grained_mode: currently unused; kept for API compatibility.
        ocr_color: forwarded as ``ocr_color`` in the fine-grained modes only.
        ocr_box: forwarded as ``ocr_box`` in the fine-grained modes only.

    Returns:
        ``(res, html_b64)``. ``res`` is the model's output. ``html_b64`` is the
        base64 (UTF-8) encoding of the rendered HTML file for "format" modes
        when that file exists, otherwise ``None``.
        On any exception (including an unsupported ``got_mode``) returns
        ``("错误: <message>", None)`` instead of raising.
    """
    try:
        if got_mode not in _GOT_MODES:
            # Reject explicitly; otherwise `res` is never bound and the caller
            # gets an unhelpful UnboundLocalError message.
            raise ValueError(f"unsupported got_mode: {got_mode!r}")
        method_name, ocr_type, fine_grained = _GOT_MODES[got_mode]

        kwargs = {"ocr_type": ocr_type}
        if fine_grained:
            kwargs["ocr_box"] = ocr_box
            kwargs["ocr_color"] = ocr_color

        result_path = None
        if ocr_type == "format":
            # Ask the model to render formatted output to an HTML file beside the image.
            result_path = f"{os.path.splitext(image_path)[0]}_result.html"
            kwargs["render"] = True
            kwargs["save_render_file"] = result_path

        res = getattr(model, method_name)(tokenizer, image_path, **kwargs)

        if result_path is None or not os.path.exists(result_path):
            return res, None
        # Decode explicitly as UTF-8, matching the UTF-8 re-encoding below,
        # instead of relying on the platform's default encoding.
        with open(result_path, "r", encoding="utf-8") as f:
            html_content = f.read()
        encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
        return res, encoded_html
    except Exception as e:
        # Errors are reported in-band as (message, None) rather than raised.
        return f"错误: {e}", None
# Usage example: load GOT-OCR 2.0 through ModelScope and OCR a single image.
if __name__ == "__main__":
    from modelscope import AutoModel, AutoTokenizer

    model_id = "stepfun-ai/GOT-OCR2_0"

    # Both the tokenizer and the model are loaded with trust_remote_code=True.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        model_id,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        device_map="cuda",
        use_safetensors=True,
    )
    model = model.eval().cuda()

    # Placeholder path; point this at a real image before running.
    image_path = "path/to/your/image.png"
    ocr_text, html_b64 = got_ocr(model, tokenizer, image_path, got_mode="format texts OCR")

    print("OCR结果:", ocr_text)
    if html_b64:
        print("HTML结果可用")