"""Gradio demo that turns a natural-language question + table context into SQL,
using the alibidaran/Gemma2_SQLGEN causal LM."""

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer
import gradio as gr
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
from typing import Iterable


class SQLGEN(Base):
    """Custom gradio theme: stone/green palette with IBM Plex Mono everywhere."""

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.stone,
        secondary_hue: colors.Color | str = colors.green,
        neutral_hue: colors.Color | str = colors.gray,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )


model_id = "alibidaran/Gemma2_SQLGEN"

# bnb_config for GPU usage (enable for 4-bit quantized loading):
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
# )

tokenizer = AutoTokenizer.from_pretrained(model_id)
# device_map='auto' lets accelerate place the weights (GPU if available, else CPU).
model = AutoModelForCausalLM.from_pretrained(model_id, device_map='auto')
tokenizer.padding_side = 'right'


def generate_sql(query, context):
    """Generate a SQL statement for *query* given the table *context*.

    Args:
        query: natural-language question, e.g. "How many users signed up?".
        context: schema/context string, e.g. a CREATE TABLE statement.

    Returns:
        The model's decoded answer (prompt tokens stripped) as a string.
    """
    text = f"##Question: {query} \n ##Context: {context} \n ##Answer:"
    # FIX: was .to('cuda') — hard-coding CUDA breaks on CPU-only hosts and can
    # mismatch the device chosen by device_map='auto'; follow the model instead.
    inputs = tokenizer(text, return_tensors='pt').to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=True,
            top_p=0.99,
            top_k=10,
            temperature=0.5,
        )
    # Drop the echoed prompt: keep only tokens generated after the input ids.
    generated = outputs[:, inputs.input_ids.shape[1]:]
    return tokenizer.decode(generated[0], skip_special_tokens=True)


interface = gr.Interface(
    generate_sql,
    ['text', 'text'],
    gr.Code(),
    title='SQLGEN',
    theme=SQLGEN(),
)

if __name__ == '__main__':
    interface.launch()