Sabbah13 committed on
Commit
f37f889
1 Parent(s): a24ffe8

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +176 -0
main.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Body, Request, File, UploadFile, BackgroundTasks, Form, Depends
2
+ from pydantic import BaseModel, constr
3
+ from huggingface_hub import HfApi
4
+ from fastapi.security import OAuth2PasswordBearer
5
+ from typing import Optional, Dict
6
+ import httpx
7
+ import os
8
+ import asyncio
9
+ import logging
10
+ from gigiachat_requests import get_access_token, get_completion_from_gigachat, get_number_of_tokens, process_transcribation_with_gigachat
11
+ from openai_requests import get_completion_from_openai, process_transcribation_with_assistant
12
+
13
# Identifier of the Hugging Face Space, read from the HF_SPACE_NAME environment variable.
repo_id = os.environ.get('HF_SPACE_NAME')

# Hugging Face Hub client used for the Space runtime/restart calls in this module.
api = HfApi()

# Bearer-token extractor used by the verify_token dependency.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Logger configuration
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Console handler that prints timestamp, level and message
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

app = FastAPI()
31
+
32
# Request body accepted by the /send_transcript endpoint.
# (Documented with comments rather than a class docstring so that the generated
# schema description stays exactly as it was.)
class TranscriptRequest(BaseModel):
    # Plain-text transcript that is processed and summarised.
    transcript: str
    # Transcript in JSON form; forwarded unchanged to final_url.
    json_transcript: Dict
    # URL that receives the final result via POST.
    final_url: str
    # Backend selector; the pipeline branches on 'GigaChat' and 'ChatGPT'.
    llm: str
    # Prompt that is prepended to the transcript for the summary request.
    base_prompt: str
    # Prompt handed to the transcript-processing helper.
    proccess_prompt: str
    # When true, the transcript is processed before being summarised.
    need_proccessing: bool
41
+
42
# Request body accepted by the /test endpoint; it has the same four keys that
# send_to_llm posts to final_url.
class FinalRequest(BaseModel):
    # Original transcript text.
    transcript: str
    # Transcript after the optional processing step.
    proccessed_transcript: str
    # Summary text.
    summary: str
    # Transcript in JSON form.
    json_transcript: Dict
47
+
48
def verify_token(token: str = Depends(oauth2_scheme)):
    """Reject the request unless the bearer token matches ``AUTH_TOKEN``.

    Args:
        token: Bearer token extracted from the request by ``oauth2_scheme``.

    Raises:
        HTTPException: 401 when the token does not match the ``AUTH_TOKEN``
            environment variable, or when that variable is unset or empty.
    """
    import hmac

    expected_token = os.environ.get("AUTH_TOKEN")

    # Fail closed when no secret is configured: an unset or empty AUTH_TOKEN
    # must never authenticate anybody.
    is_valid = bool(expected_token) and hmac.compare_digest(
        # compare_digest only accepts ASCII-only str, so compare the UTF-8 bytes;
        # the constant-time comparison avoids leaking the secret through timing.
        token.encode("utf-8"),
        expected_token.encode("utf-8"),
    )
    if not is_valid:
        raise HTTPException(
            status_code=401,
            detail="Invalid token",
            headers={"WWW-Authenticate": "Bearer"},
        )
51
+
52
# Root endpoint: replies with the text "server is running".
@app.get("/")
async def read_root():
    status_payload = {"text": "server is running"}
    return status_payload
56
+
57
@app.post("/test")
def echo_text(text_request: FinalRequest):
    # Logs every field of the received payload, then echoes back only the
    # transcript and the summary.
    logger.info(f"Final endpoint received transcript! Transcript: {text_request.transcript}.\n Proccessed transcript: {text_request.proccessed_transcript}.\n Json transcript: {text_request.json_transcript} Summary: {text_request.summary}")

    echoed = {
        "transcript": text_request.transcript,
        "summary": text_request.summary,
    }
    return echoed
63
+
64
async def send_to_llm(transcript_request: TranscriptRequest):
    """Run the LLM pipeline for a transcript and POST the result to ``final_url``.

    Steps, driven by the request fields:

    1. For ``llm == 'GigaChat'`` an access token is obtained first.
    2. When ``need_proccessing`` is true, the transcript is passed through the
       processing helper of the selected backend with ``proccess_prompt``.
    3. A summary is requested with ``base_prompt`` prepended to the processed
       transcript (or to the raw transcript when processing was skipped).
    4. The transcript, JSON transcript, processed transcript and summary are
       sent as JSON to ``transcript_request.final_url``.

    Args:
        transcript_request: Validated body of the /send_transcript request.

    Raises:
        ValueError: If ``llm`` is neither ``'GigaChat'`` nor ``'ChatGPT'``.
    """
    transcript = transcript_request.transcript
    base_prompt = transcript_request.base_prompt
    llm = transcript_request.llm
    need_proccessing = transcript_request.need_proccessing
    processing_prompt = transcript_request.proccess_prompt
    proccessed_transcript = ''
    access_token = None

    # Reject unknown backends up front. Previously such a request ran all the way
    # to the callback and then crashed because no summary had been produced.
    if llm not in ('GigaChat', 'ChatGPT'):
        raise ValueError(f"Unsupported llm {llm!r}: expected 'GigaChat' or 'ChatGPT'")

    # NOTE(review): the helper calls below are made without ``await``, so they
    # presumably block the event loop while they run - confirm and consider
    # off-loading them to a worker thread.
    if llm == 'GigaChat':
        access_token = get_access_token()
        logger.info('Got access token for GigaChat')

    if need_proccessing:
        logger.info('Strarting proccessing')
        if llm == 'GigaChat':
            number_of_tokens = get_number_of_tokens(transcript, access_token)
            logger.info('Количество токенов в транскрибации: %s', number_of_tokens)
            # Token budget for processing: transcript size plus 1000 extra tokens.
            proccessed_transcript = process_transcribation_with_gigachat(
                processing_prompt, transcript, number_of_tokens + 1000, access_token
            )
        else:
            proccessed_transcript = process_transcribation_with_assistant(processing_prompt, transcript)
        # Log the processed text (the original logged the raw transcript here).
        logger.info('Proccessed transcript: %s', proccessed_transcript)

    logger.info('Strarting summarization')
    transcript_for_summary = proccessed_transcript if need_proccessing else transcript

    # Request the summary from the selected backend.
    if llm == 'GigaChat':
        summary_answer = get_completion_from_gigachat(base_prompt + transcript_for_summary, 1024, access_token)
    else:
        summary_answer = get_completion_from_openai(base_prompt + transcript_for_summary, 1024)

    callback_payload = {
        "transcript": transcript,
        "json_transcript": transcript_request.json_transcript,
        'proccessed_transcript': proccessed_transcript,
        "summary": summary_answer,
    }
    async with httpx.AsyncClient() as client:
        response = await client.post(transcript_request.final_url, json=callback_payload)

    # Record the callback outcome so failed deliveries are visible in the logs.
    logger.info('Final url responded with status code: %s', response.status_code)
99
+
100
@app.post("/send_transcript")
async def send_transcript(transcript_request: TranscriptRequest, background_tasks: BackgroundTasks = BackgroundTasks(), token: str = Depends(verify_token)):
    # Token-protected endpoint: acknowledges the transcript immediately and
    # hands the LLM work to a background task.
    selected_llm = transcript_request.llm
    logger.info('Got transcript, starting summarization. Your llm is ' + selected_llm)

    background_tasks.add_task(send_to_llm, transcript_request)

    acknowledgement = {"message": "Transcript received, sending to llm"}
    return acknowledgement
108
+
109
async def restart_and_check_space(repo_id, poll_interval: float = 15, timeout: Optional[float] = None):
    """Restart the Space and wait until it reports the ``RUNNING`` stage.

    Args:
        repo_id: Identifier of the Space to restart.
        poll_interval: Seconds to sleep between status checks (default 15).
        timeout: Maximum number of seconds to wait for the Space. ``None``
            (the default) waits indefinitely.

    Raises:
        TimeoutError: If ``timeout`` is set and the Space is still not running
            when it elapses.
    """
    # Restart the Space.
    logger.info('Restarting space')
    api.restart_space(repo_id=repo_id)

    loop = asyncio.get_running_loop()
    deadline = None if timeout is None else loop.time() + timeout

    # Poll the stage until the Space is running or the deadline passes.
    while True:
        run_time = api.get_space_runtime(repo_id=repo_id)
        if run_time.stage == 'RUNNING':
            logger.info('Transcribation space is running, sending file')
            return

        if deadline is not None and loop.time() >= deadline:
            raise TimeoutError(
                f"Space {repo_id!r} did not reach RUNNING within {timeout} seconds "
                f"(last stage: {run_time.stage})"
            )

        logger.info('Waiting for space to be running...')
        await asyncio.sleep(poll_interval)
123
+
124
async def send_file_to_transcribation(url: str, file: UploadFile, llm: str, base_prompt: str, proccess_prompt: str, need_proccessing: bool, max_speakes: int, min_speakers: int):
    """Forward an uploaded file and its pipeline settings to the transcription Space.

    The Space is restarted first when its stage is not ``RUNNING``. The file is
    then sent as a multipart POST to the URL in ``HF_TRANSCRIBATION_SPACE_URL``,
    authorised with ``HF_TOKEN``, and the status code and body of the reply are
    logged.
    """
    space_runtime = api.get_space_runtime(repo_id=repo_id)
    if space_runtime.stage == 'RUNNING':
        logger.info('Transcribation space is running, sending file')
    else:
        await restart_and_check_space(repo_id)

    async with httpx.AsyncClient() as http_client:
        # Read the whole upload into memory as bytes.
        raw_bytes = await file.read()

        multipart_parts = {
            'file': ('file', raw_bytes, file.content_type),
            'transcript_url': (None, os.getenv('HF_TRANSCRIPT_URL')),
        }
        form_fields = dict(
            final_url=str(url),
            llm=str(llm),
            base_prompt=str(base_prompt),
            proccess_prompt=str(proccess_prompt),
            need_proccessing=need_proccessing,
            max_speakers=max_speakes,
            min_speakers=min_speakers,
        )
        auth_headers = {
            'Authorization': f'Bearer {os.environ.get("HF_TOKEN")}'
        }

        space_url = os.getenv('HF_TRANSCRIBATION_SPACE_URL')
        reply = await http_client.post(
            space_url,
            headers=auth_headers,
            files=multipart_parts,
            data=form_fields,
        )

        logger.info(f"Status code: {reply.status_code}, Data: {reply.text}")
157
+
158
+
159
@app.post("/upload")
def upload_file(
    file: UploadFile = File(...),
    url: str = Form(...),
    llm: str = Form(...),
    base_prompt: str = Form(...),
    proccess_prompt: str = Form(...),
    need_proccessing: bool = Form(...),
    max_speakers: Optional[int] = Form(None),
    min_speakers: Optional[int] = Form(None),
    background_tasks: BackgroundTasks = BackgroundTasks(),
    token: str = Depends(verify_token),
):
    # Token-protected endpoint: validates the llm choice, queues the file for
    # transcription in the background and answers right away.
    supported_llms = ('GigaChat', 'ChatGPT')
    if llm not in supported_llms:
        raise HTTPException(status_code=422, detail='Llm must be GigaChat or ChatGPT')

    # NOTE(review): the background task reads `file` after this handler has
    # returned; depending on the framework version the upload may already be
    # closed by then - confirm.
    background_tasks.add_task(
        send_file_to_transcribation,
        url,
        file,
        llm,
        base_prompt,
        proccess_prompt,
        need_proccessing,
        max_speakers,
        min_speakers,
    )

    message_parts = (
        "Got file with name: ",
        file.filename,
        ', After proccessing, transcript will be sent to ',
        url,
    )
    return {"message": "".join(message_parts)}