Spaces:
Sleeping
Sleeping
import json
import os

import faiss
import gradio as gr
import numpy as np
import openai
import requests
import torch
from huggingface_hub import InferenceClient  # Keeping Hugging Face Client as requested (no use of it is visible in this file)
from transformers import BertTokenizer, BertForSequenceClassification
def clean_payload(payload):
    """Decode one server-sent-event style line of the form ``data: {...}``.

    Args:
        payload: A single line as ``str`` or ``bytes`` (UTF-8). It may carry a
            leading ``"data:"`` prefix and trailing newline characters.

    Returns:
        The JSON-decoded value, or ``None`` if the remaining text is not valid
        JSON (e.g. an ``[DONE]`` sentinel). The decoding error is printed.
    """
    if isinstance(payload, bytes):
        # Line iterators such as requests' iter_lines() yield bytes.
        payload = payload.decode("utf-8")
    # Remove the exact "data:" prefix. str.lstrip("data:") would instead strip any
    # leading run of the characters {'d', 'a', 't', ':'} and corrupt payloads
    # such as "true" -> "rue".
    if payload.startswith("data:"):
        payload = payload[len("data:"):]
    # Drop surrounding whitespace, including "\r\n" line endings.
    cleaned_payload = payload.strip()
    try:
        json_payload = json.loads(cleaned_payload)
    except json.JSONDecodeError as e:
        print(f"JSON decoding error: {e}")
        json_payload = None
    return json_payload
# API keys and organization ID, read from environment variables.
openai.api_key = os.getenv("OPENAI_API_KEY")
openai.organization = os.getenv("OPENAI_ORG_ID")
serper_api_key = os.getenv("SERPER_API_KEY")  # SERPER API key from environment variable

# Load PubMedBERT tokenizer and model from a single model id so they cannot drift apart.
MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
# NOTE(review): this id looks like a pre-trained base checkpoint. Requesting
# num_labels=2 attaches a sequence-classification head, which transformers
# initializes randomly when the checkpoint does not provide one. Confirm a
# fine-tuned head is loaded before trusting handle_fda_query's label.
model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

# FAISS setup for vector search. The index width must equal the embedding size
# produced by embed_text (the model's hidden size), so derive it from the config
# rather than hard-coding it.
dimension = model.config.hidden_size
index = faiss.IndexFlatL2(dimension)
# Embed text (PubMedBERT)
def embed_text(text):
    """Return a mean-pooled PubMedBERT embedding of ``text``.

    Args:
        text: A string, or a list of strings to embed as one batch.

    Returns:
        A NumPy array of shape ``(batch, hidden_size)`` (``(1, hidden_size)``
        for a single string). Its dtype follows the model's dtype, presumably
        float32 as FAISS expects; confirm if the model is loaded in another dtype.
    """
    # padding=True pads only to the longest input in the batch; the attention
    # mask below makes the pooled result independent of padding length.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():  # inference only: do not build an autograd graph
        outputs = model(**inputs, output_hidden_states=True)
    last_hidden = outputs.hidden_states[-1]  # (batch, seq_len, hidden_size)
    # Average over real tokens only, so [PAD] positions do not dilute the embedding.
    mask = inputs["attention_mask"].unsqueeze(-1).to(last_hidden.dtype)
    summed = (last_hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1)
    return (summed / counts).numpy()
# Handle FDA query
def handle_fda_query(query):
    """Classify ``query`` with the PubMedBERT sequence classifier.

    Args:
        query: The user's query text.

    Returns:
        ``"FDA Query Processed: Contains regulatory info."`` when the classifier's
        top label is 1, otherwise ``"FDA Query Processed: General."``.
    """
    # NOTE(review): label 1 is assumed to mean "regulatory". That depends on the
    # classification head loaded for `model`; confirm it is fine-tuned.
    # No max_length padding is needed for a single sequence: padded positions are
    # masked out and only cost compute.
    inputs = tokenizer(query, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():  # inference only: do not build an autograd graph
        logits = model(**inputs).logits
    is_regulatory = torch.argmax(logits, dim=1).item() == 1
    return "FDA Query Processed: Contains regulatory info." if is_regulatory else "FDA Query Processed: General."
# Function to enhance info via GPT-4o-mini
def enhance_with_gpt4o(fda_response):
    """Ask GPT-4o-mini to expand on ``fda_response``.

    Args:
        fda_response: Text to be enhanced; it is interpolated into the user prompt.

    Returns:
        The model's reply text (limited to 150 tokens), or ``"Error: <details>"``
        if building the request, calling the API, or reading the reply raises.
        Errors are returned, not raised.
    """
    try:
        # NOTE(review): openai.ChatCompletion is the pre-1.0 openai SDK interface;
        # confirm the installed openai version still provides it.
        conversation = [
            {"role": "system", "content": "You are an expert FDA assistant."},
            {"role": "user", "content": f"Enhance this FDA info: {fda_response}"},
        ]
        completion = openai.ChatCompletion.create(
            model="gpt-4o-mini",
            messages=conversation,
            max_tokens=150,
        )
        first_choice = completion['choices'][0]
        return first_choice['message']['content']
    except Exception as exc:
        return "Error: " + str(exc)
def respond(message, system_message, max_tokens, temperature, top_p):
    """Answer an FDA-related ``message`` for the Gradio UI.

    First classifies the query with PubMedBERT (``handle_fda_query``). Then asks
    GPT-4o-mini to answer the user's query, passing the classifier result along
    as context.

    Args:
        message: The user's query text.
        system_message: System prompt for the chat model. Falls back to
            ``"You are an expert FDA assistant."`` when empty.
        max_tokens: Completion token limit. Coerced to ``int`` because UI sliders
            may deliver floats.
        temperature: Sampling temperature passed to the API.
        top_p: Nucleus-sampling probability mass passed to the API.

    Returns:
        A string containing both the classifier result and the model's answer,
        or ``"Error: <details>"`` if any step raises.
    """
    try:
        # First classify the query via PubMedBERT.
        fda_response = handle_fda_query(message)
        # NOTE(review): openai.ChatCompletion is the pre-1.0 openai SDK interface
        # (the same one enhance_with_gpt4o uses); confirm the installed version.
        response = openai.ChatCompletion.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_message or "You are an expert FDA assistant."},
                {
                    "role": "user",
                    # Send the user's actual question; the classifier label alone
                    # carries no information about what was asked.
                    "content": (
                        f"User query: {message}\n\n"
                        f"Enhance this FDA info: {fda_response}"
                    ),
                },
            ],
            max_tokens=int(max_tokens),
            temperature=temperature,
            top_p=top_p,
        )
        # Non-streaming request: the complete answer is in the first choice.
        enhanced_response = response['choices'][0]['message']['content']
        # Return both the PubMedBERT result and the enhanced version.
        return f"Original Info from PubMedBERT: {fda_response}\n\nEnhanced Info via GPT-4o-mini: {enhanced_response}"
    except Exception as e:
        return f"Error: {str(e)}"
# Gradio Interface. Input components map positionally onto respond()'s parameters:
# (message, system_message, max_tokens, temperature, top_p).
demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(label="Enter your FDA query", placeholder="Ask Ferris2.0 anything FDA-related."),
        gr.Textbox(value="You are Ferris2.0, the most advanced FDA Regulatory Assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        # The OpenAI Chat Completions API accepts temperature only in [0, 2];
        # larger values are rejected, so the slider stops at 2.0.
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    outputs="text",
)

if __name__ == "__main__":
    demo.launch()