File size: 4,140 Bytes
8edd1fa
cd77e73
8edd1fa
1ee7467
 
 
cd77e73
 
1d9da65
 
 
 
 
 
 
 
 
 
 
 
94cbad2
 
 
 
 
b6bcc82
1ee7467
122da82
1ee7467
94cbad2
cd77e73
122da82
cd77e73
 
 
1ee7467
 
cd77e73
 
94cbad2
cd77e73
 
122da82
94cbad2
cd77e73
94cbad2
cd77e73
1ee7467
94cbad2
 
1ee7467
0f0f232
1ee7467
0f0f232
 
 
 
 
 
 
 
 
1ee7467
 
 
 
 
 
0f0f232
 
 
 
 
 
 
 
 
 
 
 
1d9da65
fdec892
 
 
1d9da65
1ee7467
 
 
 
 
 
 
 
 
 
 
 
8edd1fa
b43bca2
8edd1fa
 
1ee7467
8edd1fa
cd77e73
8edd1fa
3d9cfb9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import gradio as gr
import openai
import os
import requests
from transformers import BertTokenizer, BertForSequenceClassification
import torch
import faiss
import numpy as np
import json

def clean_payload(payload):
    """Strip a leading ``data:`` prefix and trailing newlines, then decode JSON.

    Args:
        payload: Raw text line, e.g. ``'data: {"x": 1}\\n'``. A line without
            the ``data:`` prefix is decoded as-is.

    Returns:
        The decoded JSON value, or ``None`` if decoding fails (the decode error
        is printed). Note that a literal JSON ``null`` also yields ``None``.

    NOTE(review): a second ``clean_payload`` is defined later in this module
    and replaces this one at import time.
    """
    # str.lstrip("data:") would strip any leading run of the characters
    # d/a/t/: (e.g. "data:true" -> "rue"), so remove the literal prefix instead.
    prefix = "data:"
    if payload.startswith(prefix):
        payload = payload[len(prefix):]
    cleaned_payload = payload.rstrip("\n")
    try:
        json_payload = json.loads(cleaned_payload)
    except json.JSONDecodeError as e:
        print(f"JSON decoding error: {e}")
        json_payload = None
    return json_payload

from huggingface_hub import InferenceClient  # Keeping Hugging Face Client as requested

def clean_payload(payload):
    """Strip a leading ``data:`` prefix and trailing newlines, then decode JSON.

    Unlike the earlier definition of the same name in this module, this one
    lets decoding errors propagate to the caller.

    Args:
        payload: Raw text line, e.g. ``'data: {"x": 1}\\n'``. A line without
            the ``data:`` prefix is decoded as-is.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: If the remainder is not valid JSON.
    """
    # Remove the literal prefix; str.lstrip("data:") would also eat leading
    # 'd', 'a', 't' or ':' characters of the JSON itself ("data:true" -> "rue").
    prefix = "data:"
    if payload.startswith(prefix):
        payload = payload[len(prefix):]
    return json.loads(payload.rstrip("\n"))

# --- Credentials --------------------------------------------------------------
# Each value comes from the process environment and is None when unset.
openai.api_key = os.environ.get("OPENAI_API_KEY")
openai.organization = os.environ.get("OPENAI_ORG_ID")
# NOTE(review): serper_api_key is not referenced by the code shown in this file.
serper_api_key = os.environ.get("SERPER_API_KEY")

# --- PubMedBERT tokenizer and two-label classification model -----------------
_PUBMEDBERT_CHECKPOINT = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
tokenizer = BertTokenizer.from_pretrained(_PUBMEDBERT_CHECKPOINT)
# NOTE(review): unless this checkpoint ships a fine-tuned 2-way head, the
# classification head is freshly initialised here — confirm.
model = BertForSequenceClassification.from_pretrained(
    _PUBMEDBERT_CHECKPOINT,
    num_labels=2,
)

# --- FAISS exact (flat) L2 index for vector search ---------------------------
dimension = 768
index = faiss.IndexFlatL2(dimension)

# Embed text (PubMedBERT)
def embed_text(text):
    """Return a mean-pooled PubMedBERT embedding of *text* as a NumPy array.

    The last hidden layer is averaged over real (non-padding) tokens only,
    weighted by the tokenizer's attention mask, so padding cannot dilute the
    vector.

    Args:
        text: A string, or a list of strings to embed as one batch.

    Returns:
        ``numpy.ndarray`` with one row per input text (one row for a single
        string).
    """
    # Pad only to the longest input; the attention mask handles any padding.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():  # inference only: don't build an autograd graph
        outputs = model(**inputs, output_hidden_states=True)
    last_hidden = outputs.hidden_states[-1]  # (batch, seq, hidden)
    mask = inputs["attention_mask"].unsqueeze(-1).to(last_hidden.dtype)
    summed = (last_hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1)  # guard against division by zero
    return (summed / counts).numpy()

# Handle FDA query
def handle_fda_query(query):
    """Classify *query* with the PubMedBERT sequence-classification model.

    Args:
        query: The user's query text (a single string).

    Returns:
        ``"FDA Query Processed: Contains regulatory info."`` when the argmax
        over the model's logits is class 1, otherwise
        ``"FDA Query Processed: General."``.

    NOTE(review): if the classification head was not fine-tuned for this task,
    this verdict is effectively arbitrary — confirm the checkpoint.
    """
    inputs = tokenizer(query, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
    with torch.no_grad():  # inference only: don't build an autograd graph
        logits = model(**inputs).logits
    predicted_class = torch.argmax(logits, dim=1).item()
    if predicted_class == 1:
        return "FDA Query Processed: Contains regulatory info."
    return "FDA Query Processed: General."

# Ask GPT-4o-mini to expand on a PubMedBERT result.
def enhance_with_gpt4o(fda_response):
    """Send *fda_response* to gpt-4o-mini and return the reply text.

    The request uses a fixed FDA-assistant system prompt and a 150-token cap.

    Returns:
        The first choice's message content, or ``"Error: <message>"`` if
        building the request, calling the API, or reading the reply raises.

    NOTE(review): this uses the legacy ``openai.ChatCompletion`` interface,
    which was removed in openai>=1.0 — confirm the pinned package version.
    NOTE(review): not called by ``respond`` in the code shown in this file.
    """
    try:
        conversation = [
            {"role": "system", "content": "You are an expert FDA assistant."},
            {"role": "user", "content": f"Enhance this FDA info: {fda_response}"},
        ]
        completion = openai.ChatCompletion.create(
            model="gpt-4o-mini",
            messages=conversation,
            max_tokens=150,
        )
        first_choice = completion['choices'][0]
        return first_choice['message']['content']
    except Exception as exc:
        return "Error: " + str(exc)

def respond(message, system_message, max_tokens, temperature, top_p):
    """Answer an FDA query: PubMedBERT pre-classification, then gpt-4o-mini.

    Args:
        message: The user's query text.
        system_message: System prompt for the chat model. Falls back to a
            generic FDA-assistant prompt when empty or ``None``.
        max_tokens: Maximum number of tokens to generate. It is coerced to
            ``int`` because UI sliders may deliver floats and the API expects
            an integer.
        temperature: Sampling temperature forwarded to the API.
        top_p: Nucleus-sampling probability mass forwarded to the API.

    Returns:
        A string containing the classifier's verdict and the chat model's
        answer, or ``"Error: <message>"`` if any step raises.
    """
    try:
        # Pre-classify the query locally with PubMedBERT.
        fda_response = handle_fda_query(message)

        # Honour the system message supplied by the caller, and send the
        # user's actual question, not just the classifier's status string.
        system_prompt = (system_message or "").strip() or "You are an expert FDA assistant."
        user_prompt = (
            f"{message}\n\n"
            f"(Pre-classification from PubMedBERT: {fda_response})"
        )

        # Non-streaming request; the full reply arrives in one response object.
        response = openai.ChatCompletion.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            max_tokens=int(max_tokens),
            temperature=temperature,
            top_p=top_p,
        )

        # Concatenate every returned choice (one by default); content may be None.
        enhanced_response = "".join(
            choice['message']['content'] or "" for choice in response['choices']
        )

        # Return both the PubMedBERT result and the enhanced version
        return f"Original Info from PubMedBERT: {fda_response}\n\nEnhanced Info via GPT-4o-mini: {enhanced_response}"

    except Exception as e:
        return f"Error: {str(e)}"

# Gradio Interface
# Input order matches respond(message, system_message, max_tokens, temperature, top_p).
demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(label="Enter your FDA query", placeholder="Ask Ferris2.0 anything FDA-related."),
        gr.Textbox(value="You are Ferris2.0, the most advanced FDA Regulatory Assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        # The OpenAI chat API only accepts temperature in [0, 2]; a larger
        # slider range let users pick values the API rejects.
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    outputs="text",
)

if __name__ == "__main__":
    demo.launch()