rodrigomasini committed
Commit 9ebdc85
1 Parent(s): 7843ac8

Update app_v3.py

Files changed (1)
  1. app_v3.py +20 -37
app_v3.py CHANGED
@@ -18,41 +18,24 @@ model = AutoGPTQForCausalLM.from_quantized(model_name_or_path,
                                            quantize_config=None)
 
 
-prompt = "Tell me about AI"
-prompt_template=f'''### HUMAN:
-{prompt}
-
-### RESPONSE:
-'''
-print("\n\n*** Generate:")
-start_time = time.time()
-input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.cuda()
-streamer = TextStreamer(tokenizer)
-# output = model.generate(inputs=input_ids, temperature=0.7, max_new_tokens=512)
-# print(tokenizer.decode(output[0]))
-
-_ = model.generate(inputs=input_ids, streamer=streamer, temperature=0.7, max_new_tokens=512)
-print(f"Inference time: {time.time() - start_time:.4f} seconds")
-
-# Inference can also be done using transformers' pipeline
-
-# Prevent printing spurious transformers error when using pipeline with AutoGPTQ
-logging.set_verbosity(logging.CRITICAL)
-
-print("*** Pipeline:")
-start_time = time.time()
-
-pipe = pipeline(
-    "text-generation",
-    model=model,
-    tokenizer=tokenizer,
-    streamer=streamer,
-    max_new_tokens=512,
-    temperature=0.7,
-    top_p=0.95,
-    repetition_penalty=1.15
 )
-
-pipe(prompt_template)
-#print(pipe(prompt_template)[0]['generated_text'])
-print(f"Inference time: {time.time() - start_time:.4f} seconds")
+user_input = st.text_input("Input a phrase")
+
+prompt_template = f'USER: {user_input}\nASSISTANT:'
+
+if st.button("Generate the prompt"):
+
+    inputs_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.cuda()
+    streamer = TextStreamer(tokenizer)
+    pipe = pipeline(
+        "text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        streamer=streamer,
+        max_new_tokens=512,
+        temperature=0.2,
+        top_p=0.95,
+        repetition_penalty=1.15
+    )
+    pipe(prompt_template)
+    st.write(pipe(prompt_template)[0]['generated_text'])
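
For reference, below is a minimal sketch of what app_v3.py plausibly looks like after this commit. Only the changed hunk (old lines 18-58, new lines 18-41) is visible above, so the imports, the model path, and the tokenizer/model setup are assumptions rather than the committed code. The sketch also drops the unused inputs_ids line and folds the two pipe(prompt_template) calls into one, since each call runs a full generation pass.

# Hypothetical reconstruction of app_v3.py after commit 9ebdc85; everything
# outside the visible hunk is an assumption, not the committed code.
import streamlit as st
from auto_gptq import AutoGPTQForCausalLM
from transformers import AutoTokenizer, TextStreamer, pipeline

model_name_or_path = "..."  # actual quantized model path is not shown in the diff

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
model = AutoGPTQForCausalLM.from_quantized(model_name_or_path,
                                           device="cuda:0",
                                           use_safetensors=True,
                                           quantize_config=None)

user_input = st.text_input("Input a phrase")
prompt_template = f'USER: {user_input}\nASSISTANT:'

if st.button("Generate the prompt"):
    streamer = TextStreamer(tokenizer)  # streams tokens to stdout as they decode
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        streamer=streamer,
        max_new_tokens=512,
        temperature=0.2,
        top_p=0.95,
        repetition_penalty=1.15,
    )
    # Generate once and reuse the result instead of calling pipe() twice.
    output = pipe(prompt_template)
    st.write(output[0]['generated_text'])

Launch the app with: streamlit run app_v3.py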