# LLM_FinetuneR / handler.py
import logging
import runpod
import os
import shutil
import uuid
import json
import time
import subprocess
from typing import Dict, Any
from azure.storage.blob import BlobServiceClient
# Configure logging to print to the console
logging.basicConfig(
    level=logging.DEBUG,  # DEBUG captures more detailed logs
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler()  # Stream handler prints to the console
    ]
)

def get_azure_connection_string():
    """Get the Azure Storage connection string from an environment variable."""
    conn_string = os.environ.get("AZURE_STORAGE_CONNECTION_STRING")
    if not conn_string:
        raise ValueError("Azure Storage connection string not found in environment variables")
    return conn_string

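# The connection string is expected in the environment. An illustrative (placeholder)
# value looks like:
#   AZURE_STORAGE_CONNECTION_STRING="DefaultEndpointsProtocol=https;AccountName=<account>;AccountKey=<key>;EndpointSuffix=core.windows.net"
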
def upload_file(file_path: str) -> str:
    """Upload a local file to Azure Blob Storage and return the generated blob name."""
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"The specified file does not exist: {file_path}")

    container_name = "saasdev"
    connection_string = get_azure_connection_string()
    blob_service_client = BlobServiceClient.from_connection_string(connection_string)
    container_client = blob_service_client.get_container_client(container_name)

    # Generate a unique blob name using UUID
    blob_name = f"{uuid.uuid4()}.pdf"
    with open(file_path, 'rb') as file:
        blob_client = container_client.get_blob_client(blob_name)
        blob_client.upload_blob(file)
    logging.info(f"File uploaded to blob: {blob_name}")
    return blob_name

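# Illustrative usage (the path is a placeholder):
#   blob_name = upload_file("/tmp/work_dir/input.pdf")
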
def download_blob(blob_name: str, download_file_path: str) -> None:
    """Download a file from Azure Blob Storage"""
    container_name = "saasdev"
    connection_string = get_azure_connection_string()
    blob_service_client = BlobServiceClient.from_connection_string(connection_string)
    container_client = blob_service_client.get_container_client(container_name)
    blob_client = container_client.get_blob_client(blob_name)

    # Fall back to the current directory if the target path has no directory component
    os.makedirs(os.path.dirname(download_file_path) or ".", exist_ok=True)
    with open(download_file_path, "wb") as download_file:
        download_stream = blob_client.download_blob()
        download_file.write(download_stream.readall())
    logging.info(f"Blob '{blob_name}' downloaded to '{download_file_path}'")

def clean_directory(directory: str) -> None:
    """Clean up a directory by removing all files and subdirectories"""
    if os.path.exists(directory):
        for filename in os.listdir(directory):
            file_path = os.path.join(directory, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.remove(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                logging.error(f'Failed to delete {file_path}. Reason: {e}')

def handler(job: Dict[str, Any]) -> Dict[str, Any]:
    start_time = time.time()
    logging.info("Handler function started")

    job_input = job.get('input', {})
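
    # Illustrative example of a job payload (all values are placeholders):
    # {
    #     "input": {
    #         "pdf_file": "/path/to/document.pdf",
    #         "system_prompt": "You are a helpful assistant ...",
    #         "model_name": "my-finetuned-model",
    #         "max_step": 60,
    #         "learning_rate": 2e-4,
    #         "epochs": 1
    #     }
    # }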
    required_fields = ['pdf_file', 'system_prompt', 'model_name', 'max_step', 'learning_rate', 'epochs']
    missing_fields = [field for field in required_fields if field not in job_input]
    if missing_fields:
        return {
            "status": "error",
            "error": f"Missing required fields: {', '.join(missing_fields)}"
        }
    work_dir = os.path.abspath(f"/tmp/work_{str(uuid.uuid4())}")
    try:
        os.makedirs(work_dir, exist_ok=True)
        logging.info(f"Working directory created: {work_dir}")

        # Upload PDF to Blob
        pdf_path = job_input['pdf_file']
        generated_blob_name = upload_file(pdf_path)
        logging.info(f"PDF uploaded with blob name: {generated_blob_name}")

        # Download the uploaded PDF using the internally generated blob name
        downloaded_path = os.path.join(work_dir, "Downloaded_PDF.pdf")
        download_blob(generated_blob_name, downloaded_path)
        logging.info(f"PDF downloaded to: {downloaded_path}")

        # Save pipeline input as JSON
        pipeline_input_path = os.path.join(work_dir, "pipeline_input.json")
        pipeline_input = {
            "pdf_file": downloaded_path,
            "system_prompt": job_input['system_prompt'],
            "model_name": job_input['model_name'],
            "max_step": job_input['max_step'],
            "learning_rate": job_input['learning_rate'],
            "epochs": job_input['epochs']
        }
        with open(pipeline_input_path, 'w') as f:
            json.dump(pipeline_input, f)

        # Run fine-tuning and evaluation
        return run_pipeline_and_evaluate(pipeline_input_path, job_input['model_name'], start_time)
    except Exception as e:
        error_message = f"Job failed after {time.time() - start_time:.2f} seconds: {str(e)}"
        logging.error(error_message)
        return {
            "status": "error",
            "error": error_message
        }
    finally:
        try:
            clean_directory(work_dir)
            os.rmdir(work_dir)
        except Exception as e:
            logging.error(f"Failed to clean up working directory: {str(e)}")

def run_pipeline_and_evaluate(pipeline_input_path: str, model_name: str, start_time: float) -> Dict[str, Any]:
    """Run the fine-tuning pipeline and the evaluation script, then return the combined results."""
    try:
        # Suppress all but error-level logging output
        logging.getLogger().setLevel(logging.ERROR)

        # Read the pipeline input file
        with open(pipeline_input_path, 'r') as f:
            pipeline_input = json.load(f)

        # Convert the input to a JSON string for passing as an argument
        pipeline_input_str = json.dumps(pipeline_input)

        # Run the fine-tuning pipeline with the JSON string as argument
        # logging.info(f"Running pipeline with input: {pipeline_input_str[:100]}...")
        finetuning_result = subprocess.run(
            ['python3', 'Finetuning_Pipeline.py', pipeline_input_str],
            capture_output=True,
            text=True,
            check=True
        )
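
        # The call above assumes Finetuning_Pipeline.py accepts the serialized JSON payload
        # as its first command-line argument; adjust if that script reads input differently.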
logging.info("Fine-tuning completed successfully")
# Run evaluation
evaluation_input = json.dumps({"model_name": model_name})
result = subprocess.run(
['python3', 'VLLM_evaluation.py', evaluation_input],
capture_output=True,
text=True,
check=True
)
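
        # The parsing below assumes VLLM_evaluation.py prints its results as a single JSON
        # line containing "average_semantic_score" and "average_bleu_score"; any other
        # output falls back to returning the raw stdout.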
        try:
            # Extract the JSON part from stdout
            output_lines = result.stdout.splitlines()
            for line in reversed(output_lines):
                try:
                    evaluation_results = json.loads(line)
                    if "average_semantic_score" in evaluation_results and "average_bleu_score" in evaluation_results:
                        break
                except json.JSONDecodeError:
                    continue
            else:
                # If no valid JSON is found, fall back to the raw output
                evaluation_results = {"raw_output": result.stdout}
        except Exception as e:
            evaluation_results = {"error": f"Failed to process evaluation output: {str(e)}"}

        # Print only the JSON part to stdout for capturing in Gradio
        print(json.dumps({
            "status": "success",
            "model_name": f"PharynxAI/{model_name}",
            "processing_time": time.time() - start_time,
            "evaluation_results": evaluation_results
        }))

        return {
            "status": "success",
            "model_name": f"PharynxAI/{model_name}",
            "processing_time": time.time() - start_time,
            "evaluation_results": evaluation_results
        }
    except subprocess.CalledProcessError as e:
        error_message = f"Pipeline process failed: {e.stderr}"
        logging.error(error_message)
        return {
            "status": "error",
            "error": error_message,
            # "stdout": e.stdout,
            # "stderr": e.stderr
        }
    except Exception as e:
        error_message = f"Pipeline execution failed: {str(e)}"
        logging.error(error_message)
        return {
            "status": "error",
            "error": error_message
        }

if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
runpod.serverless.start({"handler": handler})
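
# Local testing note (assumption, not verified against this runpod version): the
# serverless worker can usually be invoked locally with a test payload, e.g.
#   python handler.py --test_input '{"input": {"pdf_file": "sample.pdf", "system_prompt": "...", "model_name": "my-model", "max_step": 60, "learning_rate": 2e-4, "epochs": 1}}'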