rodrigomasini committed
Commit 2995eda
Parent: 3124ebe

Create app_v3.py

Files changed (1)
app_v3.py: +69 -0
app_v3.py ADDED
@@ -0,0 +1,69 @@
+ from transformers import AutoTokenizer, TextStreamer, pipeline, logging
+ from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
+ import time
+
+ model_name_or_path = "TheBloke/llama2_7b_chat_uncensored-GPTQ"
+ model_basename = "gptq_model-4bit-128g"
+
+ use_triton = False
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True, legacy=False)
+
+ model = AutoGPTQForCausalLM.from_quantized(model_name_or_path,
+                                            model_basename=model_basename,
+                                            use_safetensors=True,
+                                            trust_remote_code=True,
+                                            device="cuda:0",
+                                            use_triton=use_triton,
+                                            quantize_config=None)
+
+ """
+ To download from a specific branch, use the revision parameter, as in this example:
+
+ model = AutoGPTQForCausalLM.from_quantized(model_name_or_path,
+                                            revision="gptq-4bit-32g-actorder_True",
+                                            model_basename=model_basename,
+                                            use_safetensors=True,
+                                            trust_remote_code=True,
+                                            device="cuda:0",
+                                            quantize_config=None)
+ """
+
+ prompt = "Tell me about AI"
+ prompt_template = f'''### HUMAN:
+ {prompt}
+
+ ### RESPONSE:
+ '''
+ print("\n\n*** Generate:")
+ start_time = time.time()
+ input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.cuda()
+ streamer = TextStreamer(tokenizer)
+ # output = model.generate(inputs=input_ids, temperature=0.7, max_new_tokens=512)
+ # print(tokenizer.decode(output[0]))
+
+ _ = model.generate(inputs=input_ids, streamer=streamer, temperature=0.7, max_new_tokens=512)
+ print(f"Inference time: {time.time() - start_time:.4f} seconds")
+
+ # Inference can also be done using transformers' pipeline
+
+ # Prevent printing spurious transformers error when using pipeline with AutoGPTQ
+ logging.set_verbosity(logging.CRITICAL)
+
+ print("*** Pipeline:")
+ start_time = time.time()
+
+ pipe = pipeline(
+     "text-generation",
+     model=model,
+     tokenizer=tokenizer,
+     streamer=streamer,
+     max_new_tokens=512,
+     temperature=0.7,
+     top_p=0.95,
+     repetition_penalty=1.15
+ )
+
+ pipe(prompt_template)
+ # print(pipe(prompt_template)[0]['generated_text'])
+ print(f"Inference time: {time.time() - start_time:.4f} seconds")
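For reuse beyond this single run, the streaming-generation snippet above could be wrapped in a small helper. The sketch below is a hypothetical refactor and is not part of this commit: generate_response is an assumed name, and it relies on the model and tokenizer objects created in app_v3.py.

def generate_response(prompt, max_new_tokens=512):
    # Rebuild the same HUMAN/RESPONSE prompt template used in app_v3.py.
    template = f"### HUMAN:\n{prompt}\n\n### RESPONSE:\n"
    input_ids = tokenizer(template, return_tensors="pt").input_ids.cuda()
    output = model.generate(inputs=input_ids, temperature=0.7, max_new_tokens=max_new_tokens)
    # Drop the prompt tokens so only the model's reply is decoded.
    return tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)

print(generate_response("Tell me about AI"))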