prakhardoneria committed on
Commit
5fa43e8
·
verified ·
1 Parent(s): 8a0b89f

Changed Structure

Browse files
Files changed (1) hide show
  1. app.py +21 -57
app.py CHANGED
@@ -2,71 +2,35 @@ import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
# Load DialoGPT model and tokenizer once at module import.
MODEL_NAME = "microsoft/DialoGPT-large"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
9
 
10
# Respond function for Gradio interface
def respond(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Generate one chat reply and return it with the updated history."""
    history = [] if history is None else history

    try:
        # Rebuild the conversation in 'messages' format.
        turns = [{"role": "system", "content": system_message}]
        for user_msg, bot_msg in history:
            turns.extend(
                {"role": role, "content": text}
                for role, text in (("user", user_msg), ("assistant", bot_msg))
                if text
            )
        turns.append({"role": "user", "content": message})

        # Flatten the turns to a single prompt string, tokenize, and sample.
        prompt = " ".join([msg["content"] for msg in turns])
        inputs = tokenizer.encode(prompt, return_tensors="pt")
        outputs = model.generate(
            inputs,
            max_length=max_tokens,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
        )
        # Keep only the tokens produced after the prompt.
        response = tokenizer.decode(
            outputs[:, inputs.shape[-1]:][0], skip_special_tokens=True
        )

        # Record the new exchange before returning the state.
        history.append((message, response))
        return response, history
    except Exception as e:
        # Surface the failure text to the UI instead of crashing.
        return f"Error occurred: {str(e)}", history
51
 
52
# Gradio Chat Interface
_chat_inputs = [
    gr.Textbox(label="Message"),
    gr.State(),  # conversation history carried between calls
    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(
        minimum=0.1,
        maximum=1.0,
        value=0.95,
        step=0.05,
        label="Top-p (nucleus sampling)",
    ),
]
demo = gr.Interface(
    fn=respond,
    inputs=_chat_inputs,
    outputs=["text", gr.State()],  # reply text plus updated history
)
71
 
72
  if __name__ == "__main__":
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
# Load the DialoGPT model and tokenizer once at module import.
MODEL_NAME = "microsoft/DialoGPT-large"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
 
8
 
9
# Respond function
def respond(message, chat_history=None, max_length=1000):
    """Generate a DialoGPT reply to *message*.

    Parameters
    ----------
    message : str
        The user's new utterance.
    chat_history : list | None
        Token ids of the conversation so far, as a nested Python list
        (the ``.tolist()`` form of the previous call's generate output),
        or None on the first turn.
    max_length : int, optional
        Total token budget (prompt + reply) passed to ``model.generate``.
        Defaults to 1000, the previously hard-coded value.

    Returns
    -------
    tuple[str, list]
        The bot's reply text and the updated token-id history
        (a plain list so it round-trips through ``gr.State``).
    """
    if chat_history is None:
        chat_history = []

    # Encode the user input, terminated by EOS as DialoGPT expects.
    new_user_input_ids = tokenizer.encode(
        message + tokenizer.eos_token, return_tensors="pt"
    )

    # Prepend the prior conversation tokens, if any.
    if chat_history:
        bot_input_ids = torch.cat(
            [torch.tensor(chat_history), new_user_input_ids], dim=-1
        )
    else:
        bot_input_ids = new_user_input_ids

    # Robustness: a long conversation can make the prompt reach the
    # generation budget, which previously made generate() fail. Drop the
    # oldest tokens so the prompt always fits inside max_length.
    if bot_input_ids.shape[-1] >= max_length:
        bot_input_ids = bot_input_ids[:, -(max_length - 1):]

    # Generate the bot's response; pad with EOS per DialoGPT convention.
    chat_history_ids = model.generate(
        bot_input_ids, max_length=max_length, pad_token_id=tokenizer.eos_token_id
    )
    # Decode only the newly generated tokens (everything after the prompt).
    bot_message = tokenizer.decode(
        chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True
    )

    # Hand the full token history back as a plain list for gr.State.
    chat_history = chat_history_ids.tolist()

    return bot_message, chat_history
 
 
 
26
 
27
# Gradio Interface
_inputs = ["text", gr.State()]   # user message + carried token history
_outputs = ["text", gr.State()]  # bot reply + updated token history
demo = gr.Interface(
    fn=respond,
    inputs=_inputs,
    outputs=_outputs,
    title="DialoGPT Chatbot",
    description="A chatbot powered by Microsoft's DialoGPT.",
)
35
 
36
  if __name__ == "__main__":