Spaces:
Paused
Paused
import json | |
from sentence_transformers import SentenceTransformer, util | |
import nltk | |
import os | |
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction | |
import time | |
import logging | |
import subprocess | |
import requests | |
import sys | |
import json | |
# Set the GLOO_SOCKET_IFNAME environment variable
# os.environ["GLOO_SOCKET_IFNAME"] = "lo"
# Simplified logging. basicConfig's default handler writes to stderr, which
# keeps stdout free for the JSON results printed by the evaluation.
logging.basicConfig(level=logging.INFO, format='%(message)s')
# Load pre-trained models for evaluation
semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
# Download necessary NLTK resources. Newer NLTK releases load "punkt_tab"
# (not "punkt") inside word_tokenize; without it every BLEU computation raises
# LookupError. Fetch both so old and new NLTK versions work. On an older NLTK
# that does not know "punkt_tab", download() just reports the error on stderr
# and returns False instead of raising.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
def load_input_data():
    """Parse the JSON payload passed as the first command-line argument.

    Returns:
        Whatever ``json.loads`` produces for ``sys.argv[1]``.

    Exits the process with status 1 (after logging an error) when no argument
    was supplied or when the argument is not valid JSON.
    """
    cli_args = sys.argv[1:]
    if not cli_args:
        logging.error("No input data provided")
        sys.exit(1)
    try:
        return json.loads(cli_args[0])
    except json.JSONDecodeError as err:
        logging.error(f"Failed to decode JSON input: {err}")
        sys.exit(1)
def wait_for_server(max_attempts=30, url="http://localhost:8000/health", interval=2):
    """Poll the vLLM health endpoint until it answers HTTP 200.

    Args:
        max_attempts: Maximum number of health checks to perform.
        url: Health-check endpoint to poll.
        interval: Seconds to sleep between unsuccessful attempts.

    Returns:
        True as soon as the endpoint returns 200, False if it never does
        within ``max_attempts`` checks.
    """
    for attempt in range(1, max_attempts + 1):
        try:
            response = requests.get(url, timeout=5)
        except requests.exceptions.RequestException:
            # Connection refused / timed out: the server is not up yet.
            pass
        else:
            if response.status_code == 200:
                logging.info("vLLM server is ready!")
                return True
        # Back off after every failed attempt, including non-200 replies, so
        # they cannot burn through all attempts in a tight loop. No need to
        # sleep after the final attempt.
        if attempt < max_attempts:
            time.sleep(interval)
    logging.error("vLLM server failed to start")
    return False
def start_vllm_server(model_name):
    """Launch ``vllm serve PharynxAI/<model_name>`` and wait until it is healthy.

    The child's stdout is redirected to this process's stderr. This script
    prints its JSON results on stdout, and the server's own log output would
    otherwise be interleaved with them.

    Args:
        model_name: Model name under the ``PharynxAI/`` namespace.

    Returns:
        The running ``subprocess.Popen`` handle; the caller must shut it down.

    Raises:
        RuntimeError: If the server never becomes healthy. The child process
            has already been stopped and reaped when this is raised.
    """
    cmd = [
        "vllm",
        "serve",
        f"PharynxAI/{model_name}",
        "--gpu_memory_utilization=0.98",
        "--max_model_len=4096",
        "--enable-chunked-prefill=False",
        "--num_scheduler_steps=2",
    ]
    logging.info(f"Starting vLLM server: {' '.join(cmd)}")
    server_process = subprocess.Popen(cmd, stdout=sys.stderr)
    if not wait_for_server():
        # The caller never receives this handle on failure, so stop and reap
        # the server here. It is configured to claim most of the GPU memory,
        # so it must not be left running.
        server_process.terminate()
        try:
            server_process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            server_process.kill()
            server_process.wait()
        raise RuntimeError("Server failed to start")
    return server_process
def evaluate_semantic_similarity(expected_response, model_response, semantic_model):
    """Return the cosine similarity of the two responses' embeddings.

    Each text is encoded separately with ``semantic_model.encode(...,
    convert_to_tensor=True)`` (expected response first). The 1x1 similarity
    result is converted to a plain Python number with ``.item()``.
    """
    reference_vec, candidate_vec = (
        semantic_model.encode(text, convert_to_tensor=True)
        for text in (expected_response, model_response)
    )
    return util.pytorch_cos_sim(reference_vec, candidate_vec).item()
def evaluate_bleu(expected_response, model_response):
    """Return the smoothed sentence-level BLEU of the model response.

    Both strings are lower-cased and tokenized with ``nltk.word_tokenize``.
    The expected response serves as the single reference. NLTK's smoothing
    method 1 adds a small epsilon to zero n-gram match counts.
    """
    reference, hypothesis = (
        nltk.word_tokenize(text.lower())
        for text in (expected_response, model_response)
    )
    return sentence_bleu(
        [reference],
        hypothesis,
        smoothing_function=SmoothingFunction().method1,
    )
def query_vllm_server(prompt, model_name):
    """Send ``prompt`` to the local vLLM chat-completions endpoint.

    The request uses a fixed system message plus the user prompt and a
    300-second timeout.

    Returns:
        The decoded JSON response body.

    Raises:
        Any error from the request, an HTTP error status, or JSON decoding.
        It is logged and then re-raised unchanged.
    """
    payload = {
        "model": f"PharynxAI/{model_name}",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
    }
    try:
        reply = requests.post(
            "http://localhost:8000/v1/chat/completions",
            headers={"Content-Type": "application/json"},
            json=payload,
            timeout=300,
        )
        reply.raise_for_status()
        return reply.json()
    except Exception as exc:
        logging.error(f"Server query failed: {exc}")
        raise
def evaluate_model(data, model_name, semantic_model):
    """Score the model's answer for each dataset entry and report the averages.

    Entries that cannot be scored are logged and skipped. This covers malformed
    entries, failed queries, empty ``choices``, and errors in either metric.
    An entry contributes to both averages or to neither, so the two averages
    always cover the same set of entries.

    Args:
        data: Iterable of entries, each expected to provide ``'prompt'`` and
            ``'response'`` keys.
        model_name: Passed through to ``query_vllm_server``.
        semantic_model: Passed through to ``evaluate_semantic_similarity``.

    Returns:
        ``{'average_semantic_score': ..., 'average_bleu_score': ...}``. Each
        value is 0 when no entry could be scored. The same dict is also printed
        to stdout as JSON.
    """
    semantic_scores = []
    bleu_scores = []
    for entry in data:
        try:
            # Extract fields inside the try so one malformed entry is skipped
            # instead of aborting the whole evaluation.
            prompt = entry['prompt']
            expected_response = entry['response']
            # Query the vLLM server
            response = query_vllm_server(prompt, model_name)
            # Extract model's response
            if 'choices' not in response or not response['choices']:
                logging.error(f"No choices returned for prompt: {prompt}")
                continue
            model_response = response['choices'][0]['message']['content']
            # Compute both metrics before recording either, so a failure in
            # one cannot leave the score lists covering different entries.
            semantic_score = evaluate_semantic_similarity(expected_response, model_response, semantic_model)
            bleu_score = evaluate_bleu(expected_response, model_response)
        except Exception as e:
            logging.error(f"Error processing entry: {e}")
            continue
        semantic_scores.append(semantic_score)
        bleu_scores.append(bleu_score)
    # Calculate average scores (0 when nothing was scored)
    avg_semantic_score = sum(semantic_scores) / len(semantic_scores) if semantic_scores else 0
    avg_bleu_score = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0
    # Create results dictionary
    evaluation_results = {
        'average_semantic_score': avg_semantic_score,
        'average_bleu_score': avg_bleu_score,
    }
    # Print JSON directly to stdout for capture
    print(json.dumps(evaluation_results))
    return evaluation_results
def main():
    """Run the evaluation end to end and always shut the vLLM server down.

    Steps: parse the CLI payload, read ``output_json.json``, start the server,
    and score the model. Any error raised along the way is logged and the
    process exits with status 1. The server (if started) is terminated, killed
    if it does not exit within 5 seconds, and reaped.
    """
    # Load input data
    input_data = load_input_data()
    server_process = None
    try:
        # Inside the try so a payload without "model_name" (or one that is
        # not an object) is logged and exits 1 instead of a raw traceback.
        model_name = input_data["model_name"]
        # Load dataset. JSON text is UTF-8; don't depend on the platform locale.
        with open('output_json.json', 'r', encoding='utf-8') as f:
            data = json.load(f)
        # Start vLLM server
        server_process = start_vllm_server(model_name)
        # Run evaluation
        evaluate_model(data, model_name, semantic_model)
    except Exception as e:
        logging.error(f"Evaluation failed: {e}")
        sys.exit(1)
    finally:
        # Cleanup: terminate the server, escalate to kill if it lingers, and
        # reap it in both cases so no zombie is left behind.
        if server_process is not None:
            server_process.terminate()
            try:
                server_process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                server_process.kill()
                server_process.wait()
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()