import os
import torch
import transformers
import gradio as gr
# from huggingface_hub import hf_hub_download
from huggingface_hub import snapshot_download
import safetensors
# from transformer_engine.pytorch import fp8_autocast
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit quantization keeps the model small enough for a single GPU.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

print(f"Is CUDA available: {torch.cuda.is_available()}")
# True
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
# Tesla T4
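# A richer 4-bit setup is possible (sketch, kept commented out; NF4
# quantization, double quantization, and a compute dtype are all standard
# BitsAndBytesConfig options):
# quantization_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_compute_dtype=torch.float16,
# )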
# os.environ['HF_HOME'] = '/data/.huggingface'
# Fall back to the locally cached Hugging Face login if HF_TOKEN is unset
# (token=True tells huggingface_hub to use the stored credentials).
auth_token = os.environ.get('HF_TOKEN') or True
model_id = "fcastanedo/energy_v1"
files_to_download = [
    "config.json",
    "model-00001-of-00030.safetensors",
    "model-00002-of-00030.safetensors",
    "model-00003-of-00030.safetensors",
    "model-00004-of-00030.safetensors",
    "model-00005-of-00030.safetensors",
    "model-00006-of-00030.safetensors",
    "model-00007-of-00030.safetensors",
    "model-00008-of-00030.safetensors",
    "model-00009-of-00030.safetensors",
    "model-00010-of-00030.safetensors",
    "model-00011-of-00030.safetensors",
    "model-00012-of-00030.safetensors",
    "model-00013-of-00030.safetensors",
    "model-00014-of-00030.safetensors",
    "model-00015-of-00030.safetensors",
    "model-00016-of-00030.safetensors",
    "model-00017-of-00030.safetensors",
    "model-00018-of-00030.safetensors",
    "model-00019-of-00030.safetensors",
    "model-00020-of-00030.safetensors",
    "model-00021-of-00030.safetensors",
    "model-00022-of-00030.safetensors",
    "model-00023-of-00030.safetensors",
    "model-00024-of-00030.safetensors",
    "model-00025-of-00030.safetensors",
    "model-00026-of-00030.safetensors",
    "model-00027-of-00030.safetensors",
    "model-00028-of-00030.safetensors",
    "model-00029-of-00030.safetensors",
    "model-00030-of-00030.safetensors",
    "special_tokens_map.json",
    "tokenizer.json",
    "tokenizer_config.json"
]
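# Equivalent sketch: the 30 shard names follow a fixed pattern, so the list
# above could also be generated instead of hand-written:
# files_to_download = (
#     ["config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"]
#     + [f"model-{i:05d}-of-00030.safetensors" for i in range(1, 31)]
# )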
'''
# Directory to store downloaded files
model_dir = f"./{model_id}"
os.makedirs(model_dir, exist_ok=True)
'''
# Use /data for persistent storage
model_dir = f"/data/{model_id}"
os.makedirs(model_dir, exist_ok=True)
# Download the model to persistent storage (if not already there); *.bin
# weights are skipped since the safetensors shards are what gets loaded.
if not os.path.exists(model_dir) or not os.listdir(model_dir):
    print("Downloading Weights")
    snapshot_download(repo_id=model_id, local_dir=model_dir, ignore_patterns="*.bin", token=auth_token)
'''
# Download each file individually
for file in files_to_download:
    hf_hub_download(repo_id=model_id, filename=file, local_dir=model_dir, token=auth_token)
'''
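# Alternative sketch: snapshot_download also accepts allow_patterns, so the
# explicit file list could be replaced by pattern-based filtering:
# snapshot_download(repo_id=model_id, local_dir=model_dir,
#                   allow_patterns=["*.safetensors", "*.json"], token=auth_token)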
'''
with fp8_autocast():  # Enables FP8 computations
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_dir,
        # state_dict=state_dict,
        torch_dtype=torch.float16  # Load in FP16 first, then convert
    )
'''
# Load the model and tokenizer from the local files
# model = transformers.AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.int8)
# model = transformers.AutoModelForCausalLM.from_pretrained(model_dir, load_in_4bit=True)
model = transformers.AutoModelForCausalLM.from_pretrained(model_dir, quantization_config=quantization_config)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir)
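# Optional sanity check: report how much memory the 4-bit model occupies
# (get_memory_footprint is a standard transformers model method).
print(f"Model memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")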
'''
model.to(dtype=torch.float16)  # Load as FP16 first
model = model.half()  # half() is also FP16; true FP8 is not available here
'''
# Create pipeline with the manually loaded model & tokenizer. No explicit
# device is passed: a bitsandbytes-quantized model is already placed on the
# GPU by accelerate, and moving it again would fail.
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    # model_kwargs={"torch_dtype": torch.int8},
    tokenizer=tokenizer,
    # device=3,
    # device="cuda",
    # device_map="auto",
)
'''
pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    token=auth_token,
    device=3,
    # device_map="auto",
)
'''
messages = [
    {
        "role": "system",
        "content": "You are an expert in Oil, Gas, and Petroleum for certifications like Petroleum Engineering Certificate (SPE). You will be provided Multiple Choice Questions. Select the correct response out of the four choices."
    },
    {
        "role": "user",
        "content": "Who are you?"
    }
]
prompt = pipeline.tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
# Stop on either the tokenizer's EOS or the Llama-3-style end-of-turn token.
terminators = [
    pipeline.tokenizer.eos_token_id,
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
outputs = pipeline(
    prompt,
    max_new_tokens=256,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)
# Smoke test: print only the completion, not the echoed prompt.
print(outputs[0]["generated_text"][len(prompt):])
def chat_function(message, history, system_prompt, max_new_tokens, temperature):
    messages = [{"role": "system", "content": system_prompt},
                {"role": "user", "content": message}]
    prompt = pipeline.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    terminators = [
        pipeline.tokenizer.eos_token_id,
        pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]
    outputs = pipeline(
        prompt,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        do_sample=True,
        # Offset keeps sampling valid when the slider sits at 0
        # (temperature=0 is incompatible with do_sample=True).
        temperature=temperature + 0.1,
        top_p=0.9,
    )
    # Return only the newly generated text, not the echoed prompt.
    return outputs[0]["generated_text"][len(prompt):]
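# Sketch: chat_function currently drops `history`; folding prior turns into
# the chat template (assumes Gradio's default (user, assistant) pair format
# for history) could look like this hypothetical helper:
# def build_messages(message, history, system_prompt):
#     messages = [{"role": "system", "content": system_prompt}]
#     for user_msg, assistant_msg in history:
#         messages.append({"role": "user", "content": user_msg})
#         messages.append({"role": "assistant", "content": assistant_msg})
#     messages.append({"role": "user", "content": message})
#     return messages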
gr.ChatInterface(
    chat_function,
    textbox=gr.Textbox(placeholder="Enter message here", container=False, scale=7),
    chatbot=gr.Chatbot(height=400),
    additional_inputs=[
        gr.Textbox("You are helpful AI", label="System Prompt"),
        gr.Slider(500, 4000, label="Max New Tokens"),
        gr.Slider(0, 1, label="Temperature")
    ]
).launch()
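# Note: for a shared Space handling concurrent users, queuing requests before
# launch is a common pattern (sketch): gr.ChatInterface(...).queue().launch()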