from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import os

# Fetch the Hugging Face token from the environment variables
token = os.getenv("HUGGING_FACE_HUB_TOKEN")
if not token:
    raise ValueError("Hugging Face token not found. Please add it as a secret in your Hugging Face Space.")

# Base model and fine-tuned adapter names
base_model_name = "meta-llama/Meta-Llama-3-8B"
fine_tuned_model_name = "VinitT/Sanskrit-llama"

# Load the base model, apply the fine-tuned PEFT adapter, and load the tokenizer
try:
    print("Loading base model...")
    base_model = AutoModelForCausalLM.from_pretrained(base_model_name, token=token)
    print("Loading fine-tuned model...")
    model = PeftModel.from_pretrained(base_model, fine_tuned_model_name, token=token)
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=token)
except Exception as e:
    raise RuntimeError(f"Error loading the model or tokenizer: {e}")
# Function to generate text using the model
def generate_text(input_text):
    try:
        # Tokenize the input
        inputs = tokenizer(input_text, return_tensors="pt")
        # Generate response
        outputs = model.generate(**inputs, max_length=200, num_return_sequences=1)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response
    except Exception as e:
        return f"Error during generation: {e}"
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Sanskrit Text Generation with Fine-Tuned LLaMA")
    gr.Markdown("### Enter your Sanskrit prompt below and generate text using the fine-tuned LLaMA model.")
    input_text = gr.Textbox(label="Enter your prompt in Sanskrit", placeholder="Type your text here...")
    output_text = gr.Textbox(label="Generated Text", interactive=False)
    generate_btn = gr.Button("Generate")
    generate_btn.click(fn=generate_text, inputs=input_text, outputs=output_text)
# Run the Gradio app
if __name__ == "__main__":
    demo.launch()
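# Assumed (not shown in the original): the Space's requirements.txt needs at
# least transformers, peft, torch, and gradio, plus accelerate if
# device_map="auto" is used when loading the model.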