alibidaran commited on
Commit
fa5f463
1 Parent(s): f936973

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -0
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer
3
+ import gradio as gr
4
+ from gradio.themes.base import Base
5
+ from gradio.themes.utils import colors, fonts, sizes
6
+ from typing import Iterable
7
+
8
class SQLGEN(Base):
    """Custom Gradio theme for the SQLGEN demo.

    A thin customization of the base theme: stone/green/gray hues,
    large text, and the "IBM Plex Mono" Google font used for both
    body text and monospaced text.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.stone,
        secondary_hue: colors.Color | str = colors.green,
        neutral_hue: colors.Color | str = colors.gray,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font
        | str
        | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: fonts.Font
        | str
        | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        # Collect the settings once, then forward them all to the base theme.
        theme_settings = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "spacing_size": spacing_size,
            "radius_size": radius_size,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**theme_settings)
43
+
44
+
45
+
46
# Hugging Face checkpoint of the SQL-generation model.
model_id = "alibidaran/Gemma2_SQLGEN"

# Load weights in 4-bit NF4 with bfloat16 compute so the model fits on one GPU.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Right-side padding — presumably matches how the model was fine-tuned; verify.
tokenizer.padding_side = 'right'

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map={"": 0},  # place the entire model on GPU 0
)
57
+
58
+
59
def generate_sql(query, context):
    """Generate a SQL statement for a natural-language question.

    Parameters
    ----------
    query : str
        The natural-language question to answer.
    context : str
        Schema/context text the model should condition on
        (presumably CREATE TABLE statements — confirm against the model card).

    Returns
    -------
    str
        The decoded model continuation after the "##Answer:" tag,
        with special tokens stripped.
    """
    # Prompt template the checkpoint expects; keep the exact spacing/tags.
    # (Removed the original dead statements `prompt = query` / `context = context`.)
    text = f"<s>##Question: {query} \n ##Context: {context} \n ##Answer:"
    # Use the model's own device rather than hard-coding 'cuda' so the
    # function keeps working if the model is ever placed elsewhere.
    inputs = tokenizer(text, return_tensors='pt').to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=True,
            top_p=0.99,
            top_k=10,
            temperature=0.5,
        )
    # Drop the prompt tokens; decode only the newly generated part.
    generated = outputs[:, inputs.input_ids.shape[1]:]
    return tokenizer.decode(generated[0], skip_special_tokens=True)
69
+
70
+
71
# Wire the generator into a simple two-text-input Gradio app that
# renders its output in a code widget, styled with the custom theme.
interface = gr.Interface(
    fn=generate_sql,
    inputs=['text', 'text'],
    outputs=gr.Code(),
    title='SQLGEN',
    theme=SQLGEN(),
)

if __name__ == '__main__':
    interface.launch()