Spaces:
Running
Running
import base64 | |
import io | |
import os | |
from fastapi import Depends, Query, File, UploadFile, HTTPException | |
from fastapi.templating import Jinja2Templates | |
from sqlalchemy import select | |
from sqlalchemy.ext.asyncio import AsyncSession | |
from fastapi.requests import Request | |
from project.bot import bot_router | |
from project.bot.records import FileVoiceRecord | |
from project.bot.utils import generate_ai_report, encode_file_to_base64, transcript_audio_from_base64, \ | |
generate_image_description, compress_and_save_image | |
from project.bot.models import AudioRecord, Folder, ImageRecord, Report | |
from project.bot.schemas import AudioImageIDsSchema, DataBytesSchema, RecordSchema, AudioRecordBase64Schema | |
from project.database import get_async_session | |
from project.users import current_user | |
from project.users.models import User | |
# Jinja2 template loader for this module's HTML views (``main`` renders
# ``home.html`` from it).
template = Jinja2Templates(directory="project/bot/templates")
async def main(request: Request):
    """Render the ``home.html`` page.

    The incoming request is exposed to the template context under the
    ``'request'`` key, as the template response expects.
    """
    context = {'request': request}
    return template.TemplateResponse("home.html", context)
async def create_report(ids: AudioImageIDsSchema, session: AsyncSession = Depends(get_async_session),
                        user: User = Depends(current_user)):
    """Generate and store an AI report from selected audio/image records.

    Only records that live in one of the current user's folders are used, so a
    client cannot pull other users' transcriptions into a report by sending
    their record IDs.

    Args:
        ids: Requested ``audio_ids`` / ``image_ids`` and the report ``language``.
        session: Database session.
        user: The authenticated user requesting the report.

    Returns:
        The generated report text (``report.content``).

    Raises:
        HTTPException: 403 when the user already has ``max_reports`` reports;
            400 when none of the requested IDs match the user's records.
    """
    max_reports = 12

    existing = await session.execute(select(Report.id).where(Report.user_id == user.id))
    if len(existing.scalars().all()) >= max_reports:
        raise HTTPException(status_code=403, detail=f"You cannot request more than {max_reports} reports.")

    # Scope both lookups to folders owned by the requesting user (ownership is
    # recorded on Folder.user_id) instead of trusting the raw IDs from the body.
    audio_result = await session.execute(
        select(AudioRecord.transcription)
        .join(AudioRecord.folder)
        .where(AudioRecord.id.in_(ids.audio_ids), Folder.user_id == user.id)
    )
    audio_transcriptions = list(audio_result.scalars().all())

    image_result = await session.execute(
        select(ImageRecord.transcription)
        .join(ImageRecord.folder)
        .where(ImageRecord.id.in_(ids.image_ids), Folder.user_id == user.id)
    )
    image_descriptions = list(image_result.scalars().all())

    sources = audio_transcriptions + image_descriptions
    if not sources:
        # Avoid an AI call (and consuming one of the user's report slots) for nothing.
        raise HTTPException(status_code=400, detail="No matching records were found for this report.")

    report = await generate_ai_report(sources, ids.language)
    report.user = user
    # Read the content before committing, presumably because commit may expire
    # loaded attributes and an AsyncSession cannot lazy-load them afterwards.
    report_content = report.content
    session.add(report)
    await session.commit()
    return report_content
async def get_folders(session: AsyncSession = Depends(get_async_session), user: User = Depends(current_user)):
    """Return the names of every folder whose ``user_id`` is the current user's id."""
    query = select(Folder.name).where(Folder.user_id == user.id)
    result = await session.execute(query)
    names = result.scalars().all()
    return names
async def delete_record(record: RecordSchema, session: AsyncSession = Depends(get_async_session)):
    """Delete an audio or image record together with its stored file.

    Args:
        record: ``type`` selects the table (``'audio'`` or ``'image'``) and
            ``record_id`` is the primary key of the row to delete.
        session: Database session.

    Raises:
        HTTPException: 400 for an unknown record type; 404 when no record with
            ``record_id`` exists.
    """
    # NOTE(review): no authentication/ownership check is performed, so any
    # caller can delete any record by ID — confirm this is intended.
    if record.type == 'audio':
        model, path_attr = AudioRecord, 'audio_path'
    elif record.type == 'image':
        model, path_attr = ImageRecord, 'image_path'
    else:
        raise HTTPException(status_code=400, detail=f"Unknown record type: {record.type}")

    result = await session.execute(select(model).where(model.id == record.record_id))
    db_record = result.scalars().first()
    if db_record is None:
        raise HTTPException(status_code=404, detail="Record not found.")

    file_path = getattr(db_record, path_attr)
    await session.delete(db_record)
    await session.commit()

    # Remove the stored file only after the row is gone; a file that is
    # already missing is not treated as an error.
    if file_path:
        try:
            os.remove(file_path)
        except FileNotFoundError:
            pass
async def get_folder_records(folder: int = Query(...), session: AsyncSession = Depends(get_async_session)):
    """Return the AudioRecord rows whose ``folder_id`` equals ``folder``.

    Only audio records are returned; image records in the folder are not queried.
    """
    statement = select(AudioRecord).where(AudioRecord.folder_id == folder)
    return (await session.execute(statement)).scalars().all()
async def upload_file_record(data: DataBytesSchema, session: AsyncSession = Depends(get_async_session)):
    """Wrap the uploaded bytes in a ``FileVoiceRecord`` and persist it via ``save()``."""
    await FileVoiceRecord(session, data.data_bytes).save()
async def get_record_transcription(record_id: int, model: str, session: AsyncSession = Depends(get_async_session)):
    """Return the ``transcription`` column of a single audio or image record.

    Args:
        record_id: Primary key of the record.
        model: ``'audio'`` to read from AudioRecord, ``'image'`` to read from ImageRecord.
        session: Database session.

    Returns:
        The transcription value, or ``None`` when no record matches.

    Raises:
        HTTPException: 400 when ``model`` is neither ``'audio'`` nor ``'image'``.
    """
    if model == 'audio':
        query = select(AudioRecord.transcription).where(AudioRecord.id == record_id)
    elif model == 'image':
        query = select(ImageRecord.transcription).where(ImageRecord.id == record_id)
    else:
        # Without this branch an unknown model left the query unbound and the
        # handler crashed with UnboundLocalError.
        raise HTTPException(status_code=400, detail=f"Unknown model: {model}")
    result = await session.execute(query)
    return result.scalars().first()
async def play_record(record_id: int, model: str, session: AsyncSession = Depends(get_async_session)):
    """Return the stored file of a record encoded with ``encode_file_to_base64``.

    Args:
        record_id: Primary key of the record.
        model: ``'audio'`` reads ``AudioRecord.audio_path``; any other value
            reads ``ImageRecord.image_path``.
        session: Database session.

    Raises:
        HTTPException: 404 when no record (or no stored path) exists for ``record_id``.
    """
    if model == 'audio':
        query = select(AudioRecord.audio_path).where(AudioRecord.id == record_id)
    else:
        query = select(ImageRecord.image_path).where(ImageRecord.id == record_id)
    file_path = (await session.execute(query)).scalars().first()
    if not file_path:
        # Don't hand None to the encoder; report the missing record explicitly.
        raise HTTPException(status_code=404, detail="Record not found.")
    return await encode_file_to_base64(file_path)
async def transcript_record(data: AudioRecordBase64Schema, session: AsyncSession = Depends(get_async_session),
                            user: User = Depends(current_user)):
    """Transcribe a base64 audio payload and store it as an AudioRecord.

    The record is placed in the folder returned by ``user.get_base_folder``;
    the transcription text is returned to the caller.
    """
    text, saved_path = await transcript_audio_from_base64(data.audio)
    base_folder = await user.get_base_folder(session=session)

    audio = AudioRecord()
    audio.transcription = text
    audio.audio_path = saved_path
    audio.folder = base_folder

    session.add(audio)
    await session.commit()
    return text
async def get_records(session: AsyncSession = Depends(get_async_session),
                      user: User = Depends(current_user)):
    """Return the audio and image records stored in the user's base folder.

    When ``user.get_base_folder`` yields nothing, a folder named ``'Default'``
    owned by the user is created and committed first.

    Returns:
        ``{'audio': [...AudioRecord], 'images': [...ImageRecord]}``
    """
    base_folder = await user.get_base_folder(session=session)
    if not base_folder:
        base_folder = Folder()
        base_folder.name = 'Default'
        base_folder.owner = user
        base_folder.user_id = user.id
        session.add(base_folder)
        await session.commit()

    async def _records_in_folder(model):
        # Fetch every row of ``model`` attached to the base folder.
        result = await session.execute(select(model).where(model.folder == base_folder))
        return result.scalars().all()

    # Dict literals evaluate left to right, so audio is queried before images.
    return {'audio': await _records_in_folder(AudioRecord),
            'images': await _records_in_folder(ImageRecord)}
async def upload_image(image: UploadFile = File(...), session: AsyncSession = Depends(get_async_session),
                       user: User = Depends(current_user)):
    """Describe an uploaded image with AI, save a compressed copy, and store the record.

    The image format passed to ``generate_image_description`` comes from the
    filename extension (``jpg`` is normalised to ``jpeg``). When the upload has
    no filename or no extension, the subtype of the declared content type
    (e.g. ``image/png`` -> ``png``) is used instead.

    Returns:
        The generated image description, which is also stored as the record's
        ``transcription``.
    """
    import asyncio  # local import: used only to offload the blocking compression step

    image_content = await image.read()

    filename = image.filename or ''
    file_format = filename.rsplit('.', 1)[-1].lower() if '.' in filename else ''
    if not file_format and image.content_type:
        file_format = image.content_type.split('/')[-1].lower()
    if file_format == 'jpg':
        file_format = 'jpeg'

    base64_image = base64.b64encode(image_content).decode()
    image_description = await generate_image_description(base64_image, file_format)

    # compress_and_save_image is synchronous (it is not awaited); run it in the
    # default thread pool so image processing does not block the event loop.
    loop = asyncio.get_running_loop()
    image_path = await loop.run_in_executor(None, compress_and_save_image, image_content)

    image_object = ImageRecord()
    image_object.transcription = image_description
    image_object.image_path = image_path
    image_object.folder = await user.get_base_folder(session=session)
    session.add(image_object)
    await session.commit()
    return image_description