"""FastAPI + Gradio front end for text generation with a seq2seq model.

Exposes a JSON endpoint (``POST /``) and a Gradio UI mounted on the same
FastAPI application. Both entry points share one synchronous inference
helper so they cannot drift apart.
"""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import gradio as gr

app = FastAPI()

# Initialize model and tokenizer once at import time; every request reuses them.
model_name = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)


class Query(BaseModel):
    """Request body for the generation endpoint."""

    # Prompt text handed to the tokenizer.
    inputs: str


def _run_model(prompt: str) -> str:
    """Tokenize *prompt*, run generation, and return the decoded text.

    This is deliberately a plain (non-async) function so that it can be
    called both from the async HTTP handler and from the synchronous Gradio
    callback.
    """
    # Tokenize input; prompts longer than 512 tokens are truncated.
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )

    # Generate response.
    # NOTE(review): `temperature` and `top_p` are sampling parameters; per
    # the transformers generation docs they only take effect when
    # `do_sample=True`, which is not set here — confirm whether plain beam
    # search is the intended decoding strategy.
    outputs = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_length=512,
        num_beams=4,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
        early_stopping=True,
    )

    # Decode only the first returned sequence.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


@app.post("/")
async def generate(query: Query):
    """Generate text for ``query.inputs``.

    Returns:
        A dict of the form ``{"generated_text": <str>}``.

    Raises:
        HTTPException: status 500, with the underlying error message as
            ``detail``, if tokenization, generation, or decoding fails.
    """
    try:
        return {"generated_text": _run_model(query.inputs)}
    except Exception as e:
        # Top-level boundary: surface any failure as HTTP 500 and keep the
        # original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e


# Gradio interface
def generate_text(prompt):
    """Gradio callback: return the generated text for *prompt*.

    Calls the synchronous inference helper directly. ``generate`` is a
    coroutine function, so calling it here without awaiting would yield a
    coroutine object rather than the response dict.
    """
    # Validate through the same request model the HTTP endpoint uses.
    query = Query(inputs=prompt)
    return _run_model(query.inputs)


iface = gr.Interface(
    fn=generate_text,
    inputs=gr.Textbox(lines=2, placeholder="Enter your text here..."),
    outputs="text",
    title="Medical Assistant",
    description="Ask me anything about medical topics!",
)

# Mount the Gradio app
app = gr.mount_gradio_app(app, iface, path="/")

if __name__ == "__main__":
    # Importing `train` is used for its import-time side effect; per the
    # original comment this starts the training process.
    import train  # This will start the training process