VinitT commited on
Commit
a4cea97
·
verified ·
1 Parent(s): ced0893

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -20
app.py CHANGED
@@ -1,43 +1,41 @@
1
  from peft import PeftModel
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
4
 
5
  # Load the base model, fine-tuned model, and tokenizer
6
  try:
7
- # Base model and fine-tuned model names
8
- base_model_name = "meta-llama/Meta-Llama-3-8B"
9
- fine_tuned_model_name = "VinitT/Sanskrit-llama"
10
-
11
  print("Loading base model...")
12
- base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
13
-
14
  print("Loading fine-tuned model...")
15
- model = PeftModel.from_pretrained(base_model, fine_tuned_model_name)
16
-
17
  print("Loading tokenizer...")
18
- try:
19
- tokenizer = AutoTokenizer.from_pretrained(base_model_name)
20
- except Exception as e:
21
- print(f"Failed to load tokenizer for {base_model_name}: {e}")
22
- print("Falling back to a generic tokenizer...")
23
- tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
24
 
25
  except Exception as e:
26
- print(f"Error loading the model or tokenizer: {e}")
27
- tokenizer = None # Ensure the tokenizer is defined even if loading fails
28
 
29
  # Function to generate text using the model
30
  def generate_text(input_text):
31
- if tokenizer is None:
32
- return "Tokenizer failed to load. Please check your model configuration."
33
  try:
34
  # Tokenize the input
35
  inputs = tokenizer(input_text, return_tensors="pt")
36
-
37
  # Generate response
38
  outputs = model.generate(**inputs, max_length=200, num_return_sequences=1)
39
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
40
-
41
  return response
42
  except Exception as e:
43
  return f"Error during generation: {e}"
@@ -45,6 +43,7 @@ def generate_text(input_text):
45
  # Create Gradio Interface
46
  with gr.Blocks() as demo:
47
  gr.Markdown("## Sanskrit Text Generation with Fine-Tuned LLaMA")
 
48
 
49
  input_text = gr.Textbox(label="Enter your prompt in Sanskrit", placeholder="Type your text here...")
50
  output_text = gr.Textbox(label="Generated Text", interactive=False)
 
1
  from peft import PeftModel
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import gradio as gr
4
+ import os
5
+
6
+ # Fetch Hugging Face token from environment variables
7
+ token = os.getenv("HUGGING_FACE_HUB_TOKEN")
8
+ if not token:
9
+ raise ValueError("Hugging Face token not found. Please add it as a secret in your Hugging Face Space.")
10
+
11
+ # Base model and fine-tuned model names
12
+ base_model_name = "meta-llama/Meta-Llama-3-8B"
13
+ fine_tuned_model_name = "VinitT/Sanskrit-llama"
14
 
15
  # Load the base model, fine-tuned model, and tokenizer
16
  try:
 
 
 
 
17
  print("Loading base model...")
18
+ base_model = AutoModelForCausalLM.from_pretrained(base_model_name, use_auth_token=token)
19
+
20
  print("Loading fine-tuned model...")
21
+ model = PeftModel.from_pretrained(base_model, fine_tuned_model_name, use_auth_token=token)
22
+
23
  print("Loading tokenizer...")
24
+ tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_auth_token=token)
 
 
 
 
 
25
 
26
  except Exception as e:
27
+ raise RuntimeError(f"Error loading the model or tokenizer: {e}")
 
28
 
29
  # Function to generate text using the model
30
  def generate_text(input_text):
 
 
31
  try:
32
  # Tokenize the input
33
  inputs = tokenizer(input_text, return_tensors="pt")
34
+
35
  # Generate response
36
  outputs = model.generate(**inputs, max_length=200, num_return_sequences=1)
37
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
38
+
39
  return response
40
  except Exception as e:
41
  return f"Error during generation: {e}"
 
43
  # Create Gradio Interface
44
  with gr.Blocks() as demo:
45
  gr.Markdown("## Sanskrit Text Generation with Fine-Tuned LLaMA")
46
+ gr.Markdown("### Enter your Sanskrit prompt below and generate text using the fine-tuned LLaMA model.")
47
 
48
  input_text = gr.Textbox(label="Enter your prompt in Sanskrit", placeholder="Type your text here...")
49
  output_text = gr.Textbox(label="Generated Text", interactive=False)