KJX123 committed on
Commit
0b46936
1 Parent(s): dc817ae

Create app.py

Files changed (1)
  1. app.py +52 -0
app.py ADDED
@@ -0,0 +1,52 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ # Load the fine-tuned model and tokenizer
+ def load_model_and_tokenizer(model_dir):
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(model_dir)
+         model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
+         if torch.cuda.is_available():
+             model = model.to('cuda')
+         else:
+             model = model.to('cpu')
+     except Exception as e:
+         print(f"Error loading model: {e}")
+         return None, None
+     return model, tokenizer
+
+ # Hugging Face Hub ID of the fine-tuned model
+ model_dir = "KJX123/Llama2-7b-finetune"
+ model, tokenizer = load_model_and_tokenizer(model_dir)
+
+ # Function to generate code
+ def generate_code(query):
+     if not model or not tokenizer:
+         return "Model or tokenizer not loaded properly."
+
+     prompt = f"Query: {query}\nGitHub Code:\nYouTube Code:"
+     inputs = tokenizer(prompt, return_tensors='pt')
+
+     device = 'cuda' if torch.cuda.is_available() else 'cpu'
+     inputs = inputs.to(device)
+
+     outputs = model.generate(inputs['input_ids'], max_length=600, num_return_sequences=1)
+     generated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     return generated_code
+
+ def gradio_interface(query):
+     code = generate_code(query)
+     return code
+
+ # Define Gradio app layout
+ iface = gr.Interface(
+     fn=gradio_interface,
+     inputs=gr.Textbox(lines=2, placeholder="Enter your query here..."),
+     outputs=gr.Textbox(lines=20, placeholder="Generated code will appear here..."),
+     title="Code Generator",
+     description="Enter a programming task or query to generate code using the fine-tuned model."
+ )
+
+ if __name__ == "__main__":
+     iface.launch()
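
For a quick local check of this file, the generation function can be exercised without launching the Gradio UI. This is a hypothetical smoke test, not part of the commit; it assumes gradio, torch, and transformers are installed and that the KJX123/Llama2-7b-finetune weights are accessible (importing app loads the model at module level, while the __main__ guard keeps the interface from launching):

    # Hypothetical local smoke test; the query string is only an example.
    from app import generate_code

    print(generate_code("Write a Python function that reverses a string"))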