import streamlit as st
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM
import torch
import subprocess

# Query free/total GPU memory via nvidia-smi; fall back gracefully
# if the binary is missing or the query fails (e.g. CPU-only machines)
def get_gpu_memory():
    try:
        result = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=memory.free,memory.total", "--format=csv,nounits,noheader"],
            text=True,
        )
        memory_info = [x.split(',') for x in result.strip().split('\n')]
        memory_info = [{"free": int(x[0].strip()), "total": int(x[1].strip())} for x in memory_info]
    except (FileNotFoundError, subprocess.CalledProcessError):
        memory_info = [{"free": "N/A", "total": "N/A"}]
    return memory_info

# Display GPU memory information
gpu_memory = get_gpu_memory()
st.write(f"GPU Memory Info: {gpu_memory}")

# Define pretrained model directory
pretrained_model_dir = "FPHam/Jackson_The_Formalizer_V2_13b_GPTQ"

# Check if CUDA is available and get the device
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Before loading the model, clear cached allocations if CUDA is available
if device == "cuda:0":
    torch.cuda.empty_cache()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)

# Attempt to load the quantized model, catching CUDA OOM errors
try:
    model = AutoGPTQForCausalLM.from_quantized(
        pretrained_model_dir,
        model_basename="Jackson2-4bit-128g-GPTQ",
        use_safetensors=True,
        device=device,
        max_memory={0: "10GiB"},
    )
except RuntimeError as e:
    if 'CUDA out of memory' in str(e):
        st.error("CUDA out of memory. Try reducing the model size or input length.")
        st.stop()
    else:
        raise

# User input for the model
user_input = st.text_input("Input a phrase")

# Generate button
if st.button("Generate the prompt"):
    try:
        prompt_template = f'USER: {user_input}\nASSISTANT:'
        # Tokenize and move the input tensors to the model's device;
        # padding is unnecessary for a single prompt, so only truncate.
        inputs = tokenizer(prompt_template, return_tensors='pt', max_length=512, truncation=True).to(device)
        # Without max_new_tokens, transformers falls back to a short default
        # generation length; 512 here is an arbitrary cap, adjust as needed.
        output = model.generate(**inputs, max_new_tokens=512)
        st.markdown(f"**Generated Text:**\n{tokenizer.decode(output[0], skip_special_tokens=True)}")
    except RuntimeError as e:
        if 'CUDA out of memory' in str(e):
            st.error("CUDA out of memory during generation. Try reducing the input length.")
        else:
            raise
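
# Usage note: assuming this script is saved as app.py (the filename is
# illustrative), it can be launched with Streamlit's CLI:
#
#   streamlit run app.py
#
# The app is then served locally, by default at http://localhost:8501.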