rodrigomasini committed on
Commit
009c471
1 Parent(s): e15f802

Create app_v1.py

Browse files
Files changed (1) hide show
  1. app_v1.py +45 -0
app_v1.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer
3
+ from auto_gptq import AutoGPTQForCausalLM
4
+
5
+ import os
6
+ import threading
7
+
8
+ cwd = os.getcwd()
9
+ cachedir = cwd+'/cache'
10
+
11
+ # Assuming the rest of your setup code is correct and `local_folder` is properly set up
12
+
13
+ class QuantizedModel:
14
+ def __init__(self, model_dir):
15
+ self.tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False)
16
+ self.model = AutoGPTQForCausalLM.from_quantized(
17
+ model_dir,
18
+ use_safetensors=True,
19
+ strict=False,
20
+ device="cuda:0",
21
+ use_triton=False
22
+ )
23
+
24
+ def generate(self, prompt, max_new_tokens=512, temperature=0.1, top_p=0.95, repetition_penalty=1.15):
25
+ inputs = self.tokenizer(prompt, return_tensors="pt")
26
+ outputs = self.model.generate(
27
+ input_ids=inputs['input_ids'].to("cuda:0"),
28
+ attention_mask=inputs['attention_mask'].to("cuda:0"),
29
+ max_length=max_new_tokens + inputs['input_ids'].size(-1),
30
+ temperature=temperature,
31
+ top_p=top_p,
32
+ repetition_penalty=repetition_penalty
33
+ )
34
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
35
+
36
+ quantized_model = QuantizedModel(local_folder)
37
+
38
+ user_input = st.text_input("Input a phrase")
39
+
40
+ prompt_template = f'USER: {user_input}\nASSISTANT:'
41
+
42
+ # Generate output when the "Generate" button is pressed
43
+ if st.button("Generate the prompt"):
44
+ output = quantized_model.generate(prompt_template)
45
+ st.text_area("Prompt", value=output)