rodrigomasini committed
Commit
c38f54f
1 Parent(s): ca98966

Update app_v1.py

Files changed (1)
  1. app_v1.py +18 -16
app_v1.py CHANGED
@@ -6,10 +6,10 @@ import os
  import torch
 
  # Clear up some memory
- torch.cuda.empty_cache()
+ # torch.cuda.empty_cache()
 
  # Try reducing the number of threads PyTorch uses
- torch.set_num_threads(1)
+ # torch.set_num_threads(1)
 
  cwd = os.getcwd()
  cachedir = cwd + '/cache'
@@ -53,20 +53,22 @@ model = AutoGPTQForCausalLM.from_quantized(
      quantize_config=quantize_config
  )
 
- user_input = st.text_input("Input a phrase")
+ st.write(model.hf_device_map)
 
- prompt_template = f'USER: {user_input}\nASSISTANT:'
+ #user_input = st.text_input("Input a phrase")
+
+ #prompt_template = f'USER: {user_input}\nASSISTANT:'
 
  # Generate output when the "Generate" button is pressed
- if st.button("Generate the prompt"):
-     inputs = tokenizer(prompt_template, return_tensors="pt")
-     outputs = model.generate(
-         input_ids=inputs.input_ids.to("cuda:0"),
-         attention_mask=inputs.attention_mask.to("cuda:0"),
-         max_length=512 + inputs.input_ids.size(-1),
-         temperature=0.1,
-         top_p=0.95,
-         repetition_penalty=1.15
-     )
-     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     st.text_area("Prompt", value=generated_text)
+ #if st.button("Generate the prompt"):
+ #    inputs = tokenizer(prompt_template, return_tensors="pt")
+ #    outputs = model.generate(
+ #        input_ids=inputs.input_ids.to("cuda:0"),
+ #        attention_mask=inputs.attention_mask.to("cuda:0"),
+ #        max_length=512 + inputs.input_ids.size(-1),
+ #        temperature=0.1,
+ #        top_p=0.95,
+ #        repetition_penalty=1.15
+ #    )
+ #    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+ #    st.text_area("Prompt", value=generated_text)
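In short, this commit disables the CUDA memory tweaks and the Streamlit prompt UI and instead prints model.hf_device_map, which (when the model was dispatched with a device map) records which device each module of the quantized model was assigned to. Below is a minimal sketch of how the generation flow could be re-enabled once that placement is confirmed. It assumes model (the AutoGPTQForCausalLM instance) and tokenizer are already created earlier in app_v1.py, and it keeps the hard-coded "cuda:0" target from the original code, which is only valid if the device map places the model on the first GPU.

import streamlit as st

# Assumption: `model` and `tokenizer` are defined earlier in app_v1.py, as in
# the original file; only the part this commit comments out is sketched here.

# Line added by this commit: shows the module -> device placement chosen at load time.
st.write(model.hf_device_map)

user_input = st.text_input("Input a phrase")
prompt_template = f'USER: {user_input}\nASSISTANT:'

# Generate output when the "Generate" button is pressed
if st.button("Generate the prompt"):
    inputs = tokenizer(prompt_template, return_tensors="pt")
    outputs = model.generate(
        # "cuda:0" is kept from the original code; it assumes the device map
        # printed above puts the whole model on that GPU.
        input_ids=inputs.input_ids.to("cuda:0"),
        attention_mask=inputs.attention_mask.to("cuda:0"),
        # max_length counts the prompt, so this allows at most 512 new tokens.
        max_length=512 + inputs.input_ids.size(-1),
        temperature=0.1,
        top_p=0.95,
        repetition_penalty=1.15,
    )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    st.text_area("Prompt", value=generated_text)

If the device map shows layers split across several devices, the explicit .to("cuda:0") calls would need to target whichever device holds the embedding layer instead; checking the printed map first is presumably why this commit swaps the UI for st.write(model.hf_device_map).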