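# NOTE(review): the bare text "Spaces: / Runtime error / Runtime error" stood here.
# It looks like page-scrape residue rather than source code, so it is kept only as a comment.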
import re | |
import gradio as gr | |
from typing import List | |
from models import models | |
from loguru import logger | |
import re | |
# Prompt layout passed as `prompt_template` to the LLM calls in this module:
# the instruction ({query_str}) on the first line, the text to process
# ({context_str}) on the second.
PROMPT_TEMPLATE = (
    "使用中文{query_str}:\n"
    "{context_str}\n"
)
def get_text_lines(input_txt: str) -> List[str]:
    """Return the non-blank lines of *input_txt*.

    Each line is stripped of leading/trailing whitespace; lines that are
    empty after stripping are dropped.
    """
    stripped = (raw.strip() for raw in input_txt.splitlines())
    return [entry for entry in stripped if entry]
# Characters after which split_in_line may cut a long line.
stop_chars_set = {
    # ASCII and CJK sentence / clause terminators.
    '.', '!', '?', '。', '…', ';', ':',
    # Full-width forms of ! ? ; : ) that are common in Chinese text.  Written
    # as escapes so they survive any Unicode normalization of this file.
    '\uff01', '\uff1f', '\uff1b', '\uff1a', '\uff09',
    # Closing quotes and brackets, which may trail a terminator.
    '”', '’', ')', '】', '》', '」', '』', '〕', '〉',
    '〗', '〞', '〟', '»', '"', "'", ']', '}',
}
def split_in_line(input_txt: str, limit_length: int, stop_chars=None) -> List[str]:
    """Split one long line into chunks that end on sentence boundaries.

    The text is first cut after every character found in *stop_chars*; the
    resulting sentences are then concatenated greedily, and a chunk is
    emitted as soon as its length reaches *limit_length*.

    Note that *limit_length* is a flush threshold, not a hard cap: a chunk
    can be longer than it, and a single sentence without any stop character
    is never cut.

    Args:
        input_txt: The text to split.
        limit_length: Length at which the accumulated chunk is emitted.
        stop_chars: Container of boundary characters. Defaults to the
            module-level ``stop_chars_set``.

    Returns:
        The chunks in order; concatenating them yields *input_txt*.
        An empty input yields an empty list.
    """
    if stop_chars is None:
        stop_chars = stop_chars_set

    # Pass 1: cut the text after every stop character.
    sentences = []
    start = 0
    for pos, char in enumerate(input_txt):
        if char in stop_chars:
            sentences.append(input_txt[start:pos + 1])
            start = pos + 1
    # Keep a trailing fragment that does not end with a stop character.
    # Checking the offset (rather than input_txt[-1]) is safe for "".
    if start < len(input_txt):
        sentences.append(input_txt[start:])

    # Pass 2: merge sentences until the threshold is reached.
    outputs = []
    pending = []
    pending_length = 0
    for sentence in sentences:
        pending.append(sentence)
        pending_length += len(sentence)
        if pending_length >= limit_length:
            outputs.append("".join(pending))
            pending = []
            pending_length = 0
    # Emit the remainder only when there is one, so no empty chunk is
    # produced when the last sentence triggered a flush.
    if pending:
        outputs.append("".join(pending))
    return outputs
def get_text_limit_length(input_txt: str, max_length: int = 2048) -> List[str]:
    """Break *input_txt* into non-blank lines, sub-splitting the long ones.

    Lines of at most *max_length* characters are kept whole; longer lines
    are handed to ``split_in_line`` and replaced by the pieces it returns.
    """
    chunks: List[str] = []
    for entry in get_text_lines(input_txt):
        if len(entry) > max_length:
            pieces = split_in_line(entry, max_length)
            logger.debug(f"split in line: {len(pieces)}")
            chunks.extend(pieces)
            continue
        chunks.append(entry)
    return chunks
def split_input_text(input_txt, strip_input_lines=0, max_length=2048):
    """Segment raw text and return the segments joined by three newlines.

    Args:
        input_txt: The raw text to segment.
        strip_input_lines: When greater than 0, every run of at least this
            many CR/LF characters is deleted before segmenting. The run is
            removed, not replaced, so the text on both sides is joined.
            A "\\r\\n" pair counts as two characters. 0 disables this step.
        max_length: Passed to ``get_text_limit_length`` as the line length
            above which a line is sub-split.

    Returns:
        The segments joined with "\\n\\n\\n", the separator the
        ``gen_*`` functions in this module split on.
    """
    # UI sliders may deliver the value as a float; "{2.0,}" is not a valid
    # regex quantifier and would silently match nothing.
    run_length = int(strip_input_lines)
    if run_length > 0:
        pattern = r'[\r\n]{' + str(run_length) + r',}'
        logger.debug(f"strip input txt: {pattern}")
        input_txt = re.sub(pattern, '', input_txt)
    lines = get_text_limit_length(input_txt, max_length)
    logger.debug(f"split input txt: {len(lines)}")
    return "\n\n\n".join(lines)
def gen_keyword_summary(input_txt, keyword_prompt, summary_prompt, max_length=2048):
    """Build a summary prompt that is prefixed with extracted keywords.

    *input_txt* is split on "\\n\\n\\n"; for each non-blank segment the LLM
    is called with *keyword_prompt*, and the first element of its result is
    split on whitespace into keywords.

    Args:
        input_txt: Segmented text, segments separated by three newlines.
        keyword_prompt: Instruction sent to the model with every segment.
        summary_prompt: Text appended after the keyword list in the result.
        max_length: Forwarded to ``generate_answer`` as ``max_length``.

    Returns:
        A prompt string containing the de-duplicated keywords, in order of
        first appearance, followed by *summary_prompt*.
    """
    # A dict is used as an ordered set: it drops duplicates while keeping
    # first-seen order, so the result is stable between runs.
    unique_keywords = {}
    for segment in input_txt.split("\n\n\n"):
        if not segment.strip():
            # Nothing to extract from; skip the model call.
            continue
        keywords = models.llm_model.generate_answer(
            keyword_prompt,
            segment,
            history=None,
            max_length=max_length,
            prompt_template=PROMPT_TEMPLATE
        )[0]
        logger.debug(f"text len: {len(segment)} ==> {keywords}")
        # str.split() with no argument never yields empty or padded tokens.
        unique_keywords.update(dict.fromkeys(keywords.split()))
    return f"保留关键词:{' '.join(unique_keywords)},{summary_prompt}"
def gen_summary(input_txt, summary_prompt, max_length=2048):
    """Summarize segmented text with a rolling summary.

    *input_txt* is split on "\\n\\n\\n". The first segment is summarized on
    its own; every later segment is summarized together with the summary
    produced for the previous one, so context carries forward.

    Args:
        input_txt: Segmented text, segments separated by three newlines.
        summary_prompt: Instruction sent to the model with every segment.
        max_length: Forwarded to ``generate_answer`` as ``max_length``.

    Returns:
        One summary per segment, joined with "\\n\\n\\n".
    """
    output_summary = []
    summary = ""
    for idx, line in enumerate(input_txt.split("\n\n\n")):
        is_first = idx == 0  # enumerate starts at 0, so this is the first segment
        # Later segments are prefixed with the previous summary.
        context = line if is_first else f"{summary}{line}"
        summary = models.llm_model.generate_answer(
            summary_prompt,
            context,
            history=None,
            max_length=max_length,
            prompt_template=PROMPT_TEMPLATE
        )[0]
        if is_first:
            logger.debug(f"text len: {len(line)} ==> {summary}")
        else:
            logger.debug(f"summary: {len(summary)} + text: {len(line)} ==> {summary}")
        output_summary.append(summary)
    return "\n\n\n".join(output_summary)
def summary_ui():
    """Create the summarization widgets and wire up their button handlers.

    Layout: a settings/prompt row, a row with the input, segmented and
    summary text boxes, and a row with the three action buttons.
    Presumably called inside an enclosing ``gr.Blocks`` context — confirm
    against the caller. Returns nothing.
    """
    with gr.Row():
        with gr.Column(scale=1):
            line_max_length = gr.Slider(minimum=512, maximum=4096, step=1, value=1024, label="每行最大长度")
            # The minimum must be 0: the label documents 0 as "do not strip"
            # and 0 is the default value.
            strip_input_lines = gr.Slider(
                label="去除输入文本连续的空行(0:不除去)",
                minimum=0,
                maximum=10,
                step=1,
                value=0
            )
        with gr.Column(scale=4):
            keyword_prompt = gr.Textbox(
                lines=1,
                label="抽取关键词",
                value="抽取以下内容的人物和地点:",
                placeholder="请输入抽取关键词的Prompt"
            )
            summary_prompt = gr.Textbox(
                lines=2,
                label="生成摘要",
                value="生成以下内容的摘要:",
                placeholder="请输入生成摘要的Prompt"
            )
            # Filled by the keyword button and read by the summary button.
            keyword_summary_prompt = gr.Textbox(lines=4, label="关键词+摘要", placeholder="请输入关键词+摘要的Prompt")
    with gr.Row():
        input_text = gr.Textbox(lines=20, max_lines=60, label="输入文本", placeholder="请输入文本")
        split_text = gr.Textbox(lines=20, max_lines=60, label="分段文本", placeholder="请输入分段文本")
        summary = gr.Textbox(lines=20, max_lines=60, label="生成摘要", placeholder="请输入生成摘要的Prompt")
    with gr.Row():
        btn_split = gr.Button("分段")
        btn_keyword = gr.Button("提取关键词")
        btn_summary = gr.Button("生成摘要")
    # Step 1: raw input -> segmented text.
    btn_split.click(
        split_input_text,
        inputs=[input_text, strip_input_lines, line_max_length],
        outputs=[split_text]
    )
    # Step 3: segmented text + combined prompt -> summaries.
    btn_summary.click(
        gen_summary,
        inputs=[split_text, keyword_summary_prompt, line_max_length],
        outputs=[summary]
    )
    # Step 2: segmented text -> combined keyword + summary prompt.
    btn_keyword.click(
        gen_keyword_summary,
        inputs=[split_text, keyword_prompt, summary_prompt, line_max_length],
        outputs=[keyword_summary_prompt]
    )