# HexaGrid / utils / lora_details.py
# Surn's picture
# Thumbnail update #1
# ced6a2a
# utils/lora_details.py
import gradio as gr
from utils.constants import LORA_DETAILS
def upd_prompt_notes_by_index(lora_index):
    """
    Updates the prompt_notes_label with the notes from LORAS based on index.

    If the selected entry has no 'notes' value, a hint is built from its
    'trigger_position' and 'trigger_word' values instead. If there is no
    usable entry (index out of range, index that cannot be used to look up
    LORAS such as None, or an empty entry), a generic fallback message is
    shown.

    Args:
        lora_index (int): The index of the selected LoRA model.

    Returns:
        gr.update: Updated Gradio label component with the notes.
    """
    # Bound before any lookup so `notes` always has a value at the return;
    # previously an empty/falsy entry left it unassigned (UnboundLocalError).
    notes = "Enter Prompt description of your image, \nusing models without LoRa may take a 30 minutes."
    try:
        # NOTE(review): LORAS is not among the names imported at the top of
        # the visible file -- confirm it is in scope at runtime.
        lora = LORAS[lora_index]
    except (IndexError, TypeError):
        # Out-of-range index, or a value that cannot index LORAS (e.g. None).
        lora = None
    if lora:
        lora_notes = lora.get('notes', None)
        if lora_notes is None:
            trigger_word = lora.get('trigger_word', "")
            trigger_position = lora.get('trigger_position', "")
            lora_notes = f"{trigger_position} '{trigger_word}' in prompt"
        notes = lora_notes
    return gr.update(value=notes)
def get_trigger_words_by_index(lora_index):
    """
    Looks up the 'trigger_word' value of the LORAS entry at the given index.

    Args:
        lora_index (int): The index of the selected LoRA model.

    Returns:
        str: The entry's 'trigger_word' value, or an empty string when the
        entry has no such key or the index is out of range.
    """
    try:
        return LORAS[lora_index].get('trigger_word', "")
    except IndexError:
        # No entry at that position: report "no trigger words".
        return ""
def upd_prompt_notes(model_textbox_value):
    """
    Builds the prompt_notes_label update for a model name using LORA_DETAILS.

    Args:
        model_textbox_value (str): The name of the LoRA model.

    Returns:
        gr.update: Label update carrying the first 'notes' value found in the
        model's detail list, an empty string when the model is listed but no
        detail item has 'notes', or a generic fallback message when the model
        is not in LORA_DETAILS.
    """
    # NOTE(review): the fallback is treated as belonging to the membership
    # test (model not listed); confirm against the original indentation.
    if model_textbox_value not in LORA_DETAILS:
        return gr.update(
            value="Enter Prompt description of your image, \nusing models without LoRa may take a 30 minutes."
        )
    # First detail item that carries notes wins; default to an empty label.
    first_notes = next(
        (
            detail['notes']
            for detail in LORA_DETAILS[model_textbox_value]
            if 'notes' in detail
        ),
        "",
    )
    return gr.update(value=first_notes)
def get_trigger_words(model_textbox_value):
    """
    Fetches the trigger words recorded in LORA_DETAILS for a model name.

    Args:
        model_textbox_value (str): The name of the LoRA model.

    Returns:
        str: The first 'trigger_words' value found in the model's detail
        list, or an empty string when the model is not listed or none of its
        detail items has that key.
    """
    if model_textbox_value not in LORA_DETAILS:
        return ""
    for detail in LORA_DETAILS[model_textbox_value]:
        if 'trigger_words' in detail:
            # Only the first matching detail item is used.
            return detail['trigger_words']
    return ""
def upd_trigger_words(model_textbox_value):
    """
    Builds the trigger_words_label update for a model name.

    Args:
        model_textbox_value (str): The name of the LoRA model.

    Returns:
        gr.update: Label update whose value is whatever get_trigger_words
        returns for the model.
    """
    return gr.update(value=get_trigger_words(model_textbox_value))
def approximate_token_count(prompt, tokens_per_word=1.3):
    """
    Approximates the number of tokens in a prompt based on word count.

    Words are whitespace-separated chunks; the estimate is the word count
    multiplied by tokens_per_word, truncated to an integer.

    Args:
        prompt (str): The text prompt. An empty or None prompt counts as
            zero tokens.
        tokens_per_word (float): Average tokens per word. The right value
            varies with language and tokenizer; defaults to 1.3.

    Returns:
        int: The approximate number of tokens.
    """
    if not prompt:
        # Nothing to count (also avoids calling .split() on None).
        return 0
    return int(len(prompt.split()) * tokens_per_word)
def split_prompt_by_tokens(prompt, token_number, tokens_per_word=1.3):
    """
    Splits a prompt into two parts at an approximate token budget.

    The budget is converted to a word count using the tokens-per-word
    ratio (words = tokens / tokens_per_word), so the first part is estimated
    to fit within token_number tokens. Both parts are rebuilt by joining the
    whitespace-separated words with single spaces, so the original spacing
    is not preserved.

    Args:
        prompt (str): The text prompt. An empty or None prompt yields two
            empty strings.
        token_number (int): Approximate token budget for the first part.
            Values below zero are treated as zero.
        tokens_per_word (float): Average tokens per word; must be positive.
            Defaults to 1.3.

    Returns:
        tuple[str, str]: The words that fit the budget, and the remaining
        words.

    Raises:
        ValueError: If tokens_per_word is not positive.
    """
    if tokens_per_word <= 0:
        raise ValueError("tokens_per_word must be positive")
    words = prompt.split() if prompt else []
    # Tokens -> words is a division: each word costs ~tokens_per_word tokens.
    # Clamp at zero so a negative budget does not turn into a slice that
    # counts from the end of the list.
    word_limit = max(0, int(token_number / tokens_per_word))
    return ' '.join(words[:word_limit]), ' '.join(words[word_limit:])
# Split prompt precisely by token count
import tiktoken
def split_prompt_precisely(prompt, max_tokens=77, model="gpt-3.5-turbo"):
    """
    Splits a prompt at an exact token count using a tiktoken encoding.

    Args:
        prompt (str): The text prompt.
        max_tokens (int): Number of tokens to keep in the first part.
        model (str): Model name passed to tiktoken.encoding_for_model; if
            that lookup raises KeyError, the "cl100k_base" encoding is used.

    Returns:
        tuple[str, str]: (prompt, "") unchanged when the prompt encodes to at
        most max_tokens tokens; otherwise the decoded first max_tokens tokens
        and the decoded remainder.
    """
    try:
        codec = tiktoken.encoding_for_model(model)
    except KeyError:
        codec = tiktoken.get_encoding("cl100k_base")

    token_ids = codec.encode(prompt)
    if len(token_ids) <= max_tokens:
        # Already within budget: hand back the original text untouched.
        return prompt, ""

    head_ids, tail_ids = token_ids[:max_tokens], token_ids[max_tokens:]
    return codec.decode(head_ids), codec.decode(tail_ids)