Spaces:
Sleeping
Sleeping
File size: 4,277 Bytes
8edd1fa cd77e73 8edd1fa 1ee7467 cd77e73 1d9da65 94cbad2 1d9da65 94cbad2 b6bcc82 1ee7467 122da82 1ee7467 94cbad2 cd77e73 122da82 cd77e73 1ee7467 cd77e73 94cbad2 cd77e73 122da82 94cbad2 cd77e73 94cbad2 cd77e73 1ee7467 94cbad2 1ee7467 0f0f232 1ee7467 0f0f232 1ee7467 0f0f232 1d9da65 1ee7467 8edd1fa b43bca2 8edd1fa 1ee7467 8edd1fa cd77e73 8edd1fa 3d9cfb9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import json
import os

import faiss
import gradio as gr
import numpy as np
import openai
import requests
import torch
from huggingface_hub import InferenceClient  # Keeping Hugging Face Client as requested
from transformers import BertTokenizer, BertForSequenceClassification
def clean_payload(payload):
    """Parse one server-sent-event style line into a JSON value.

    An optional leading ``data:`` prefix is removed, surrounding whitespace
    (including the trailing newline) is stripped, and the remainder is
    decoded as JSON.

    Args:
        payload: A single line such as ``data: {"content": "hi"}``.

    Returns:
        The decoded JSON value, or ``None`` when the remainder is not valid
        JSON (for example an end-of-stream marker like ``[DONE]``).
    """
    prefix = "data:"
    # str.lstrip("data:") would strip any leading run of the characters
    # d/a/t/: instead of the literal prefix, corrupting payloads such as
    # "data:true" (-> "rue"); remove the prefix explicitly instead.
    if payload.startswith(prefix):
        payload = payload[len(prefix):]
    cleaned_payload = payload.strip()
    try:
        return json.loads(cleaned_payload)
    except json.JSONDecodeError as e:
        print(f"JSON decoding error: {e}")
        return None
# --- Credentials ---------------------------------------------------------
# Every secret is read from the environment; each value is None when the
# corresponding variable is not set.
openai.api_key = os.environ.get("OPENAI_API_KEY")
openai.organization = os.environ.get("OPENAI_ORG_ID")
serper_api_key = os.environ.get("SERPER_API_KEY")  # SERPER API key from environment variable

# --- PubMedBERT tokenizer and classifier ---------------------------------
_PUBMEDBERT_CHECKPOINT = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
tokenizer = BertTokenizer.from_pretrained(_PUBMEDBERT_CHECKPOINT)
# NOTE(review): this checkpoint is presumably a plain pretrained encoder, so
# the 2-label classification head requested via num_labels=2 is likely
# freshly initialised rather than fine-tuned -- confirm before relying on
# the labels produced by handle_fda_query.
model = BertForSequenceClassification.from_pretrained(
    _PUBMEDBERT_CHECKPOINT,
    num_labels=2,
)

# --- FAISS vector index ---------------------------------------------------
# Flat (exhaustive) L2 index; 768 presumably matches the encoder's hidden
# size used by embed_text. NOTE(review): nothing in the visible code adds
# vectors to it yet.
dimension = 768
index = faiss.IndexFlatL2(dimension)
# Embed text (PubMedBERT)
def embed_text(text):
    """Return mean-pooled PubMedBERT embeddings for ``text``.

    Args:
        text: A string, or a list of strings, accepted by ``tokenizer``.

    Returns:
        A numpy array of shape ``(batch, hidden_size)``: for each input, the
        average of the last encoder layer's hidden states over its real
        (non-padding) tokens.
    """
    # Pad only to the longest input in the batch; padding positions are
    # masked out of the average below, so longer padding would only waste work.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    last_hidden = outputs.hidden_states[-1]  # (batch, seq_len, hidden_size)
    # A plain .mean(dim=1) would also average the [PAD] positions and dilute
    # the embedding, so weight by the attention mask instead.
    mask = inputs["attention_mask"].unsqueeze(-1).to(last_hidden.dtype)
    summed = (last_hidden * mask).sum(dim=1)
    token_counts = mask.sum(dim=1).clamp(min=1)  # guard against division by zero
    return (summed / token_counts).numpy()
# Handle FDA query
def handle_fda_query(query):
    """Classify ``query`` with the PubMedBERT sequence classifier.

    Args:
        query: A single query string (``.item()`` below requires exactly one
            prediction, so a batch of several queries raises).

    Returns:
        ``"FDA Query Processed: Contains regulatory info."`` when the
        predicted label is 1, otherwise ``"FDA Query Processed: General."``.
    """
    inputs = tokenizer(query, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
    # Inference only: no_grad avoids building and holding an autograd graph
    # for every request.
    with torch.no_grad():
        logits = model(**inputs).logits
    # NOTE(review): the label semantics depend on the classifier head, which
    # may not be fine-tuned for this task -- confirm what label 1 means.
    predicted_label = torch.argmax(logits, dim=1).item()
    if predicted_label == 1:
        return "FDA Query Processed: Contains regulatory info."
    return "FDA Query Processed: General."
# Function to enhance info via GPT-4o-mini
def enhance_with_gpt4o(fda_response):
    """Ask gpt-4o-mini to expand on ``fda_response``.

    Returns:
        The assistant's reply text (requested with ``max_tokens=150``), or
        ``"Error: <message>"`` if building the request, the API call, or
        reading the reply raises.
    """
    try:
        chat = [
            {"role": "system", "content": "You are an expert FDA assistant."},
            {"role": "user", "content": f"Enhance this FDA info: {fda_response}"},
        ]
        # NOTE(review): openai.ChatCompletion is the pre-1.0 SDK interface;
        # confirm the installed openai version still provides it.
        completion = openai.ChatCompletion.create(
            model="gpt-4o-mini",
            messages=chat,
            max_tokens=150,
        )
        first_choice = completion['choices'][0]
        return first_choice['message']['content']
    except Exception as exc:
        return f"Error: {str(exc)}"
def respond(message, system_message, max_tokens, temperature, top_p):
    """Gradio handler: classify ``message`` with PubMedBERT, then ask
    gpt-4o-mini to answer/expand on it.

    Args:
        message: The user's FDA-related question.
        system_message: System prompt for the chat model; when empty, a
            generic FDA-assistant prompt is used instead.
        max_tokens: Upper bound on completion tokens; coerced to ``int``
            because the API expects an integer and slider values may arrive
            as floats.
        temperature: Sampling temperature forwarded to the API.
        top_p: Nucleus-sampling cutoff forwarded to the API.

    Returns:
        A text block containing the PubMedBERT classification and the chat
        model's reply, or ``"Error: <message>"`` if any step raises (this is
        the UI boundary, so errors are shown to the user rather than raised).
    """
    try:
        # First retrieve info via PubMedBERT
        fda_response = handle_fda_query(message)
        # NOTE(review): openai.ChatCompletion is the pre-1.0 SDK interface,
        # kept for consistency with enhance_with_gpt4o; confirm the pinned
        # openai version.
        response = openai.ChatCompletion.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_message or "You are an expert FDA assistant."},
                # Include the user's question so the model has something to
                # answer beyond the fixed classification string.
                {"role": "user", "content": f"User query: {message}\n\nEnhance this FDA info: {fda_response}"},
            ],
            max_tokens=int(max_tokens),
            temperature=temperature,
            top_p=top_p,
        )
        # Non-streaming call: the reply text is in the first choice.
        enhanced_response = response['choices'][0]['message']['content']
        # Return both the PubMedBERT result and the enhanced version
        return f"Original Info from PubMedBERT: {fda_response}\n\nEnhanced Info via GPT-4o-mini: {enhanced_response}"
    except Exception as e:
        return f"Error: {str(e)}"
# Gradio Interface
# The input components are passed to respond() positionally, in this order:
# message, system_message, max_tokens, temperature, top_p.
demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(label="Enter your FDA query", placeholder="Ask Ferris2.0 anything FDA-related."),
        gr.Textbox(value="You are Ferris2.0, the most advanced FDA Regulatory Assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        # The OpenAI chat API only accepts temperature in [0, 2]; a larger
        # slider range lets users pick values the API rejects.
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    outputs="text",
)
# Launch the Gradio app only when run as a script, not when imported.
# (The stray trailing "|" on the launch line was a syntax error.)
if __name__ == "__main__":
    demo.launch()