|
|
|
|
|
import gradio as gr |
|
from utils.constants import LORA_DETAILS |
|
def upd_prompt_notes_by_index(lora_index, loras=None):
    """
    Updates the prompt_notes_label with the notes from LORAS based on index.

    The entry's ``'notes'`` value is used when present. Otherwise a hint is
    built from its ``'trigger_position'`` and ``'trigger_word'`` values. A
    default prompt hint is shown when no usable entry exists: the index is
    out of range or not a valid index (e.g. ``None``), or the entry is empty.

    NOTE(review): ``LORAS`` is not imported or defined in the visible part of
    this file (only ``LORA_DETAILS`` is imported from ``utils.constants``).
    Confirm where it comes from; if it is missing, calling this without
    ``loras`` raises ``NameError``.

    Args:
        lora_index (int): The index of the selected LoRA model.
        loras (sequence, optional): The LoRA entries to look up. Defaults to
            the module-level ``LORAS``.

    Returns:
        gr.update: Updated Gradio label component with the notes.
    """
    default_notes = "Enter Prompt description of your image, \nusing models without LoRa may take a 30 minutes."
    if loras is None:
        loras = LORAS

    try:
        lora = loras[lora_index]
    except (LookupError, TypeError):
        # Out-of-range index/key, or no usable index (e.g. None).
        lora = None

    if not lora:
        # Missing or empty entry: always show the default hint, so the
        # returned value is never unbound.
        return gr.update(value=default_notes)

    notes = lora.get('notes', None)
    if notes is None:
        # No explicit notes: describe where the trigger word should go.
        trigger_word = lora.get('trigger_word', "")
        trigger_position = lora.get('trigger_position', "")
        notes = f"{trigger_position} '{trigger_word}' in prompt"

    return gr.update(value=notes)
|
|
|
def get_trigger_words_by_index(lora_index, loras=None):
    """
    Retrieves the trigger words from LORAS for the specified index.

    NOTE(review): ``LORAS`` is not imported or defined in the visible part of
    this file (only ``LORA_DETAILS`` is imported from ``utils.constants``).
    Confirm where it comes from; if it is missing, calling this without
    ``loras`` raises ``NameError``.

    Args:
        lora_index (int): The index of the selected LoRA model.
        loras (sequence, optional): The LoRA entries to look up. Defaults to
            the module-level ``LORAS``.

    Returns:
        str: The entry's ``'trigger_word'`` value, or an empty string when the
        index is out of range or invalid (e.g. ``None``), the entry is empty,
        or the entry has no ``'trigger_word'`` key.
    """
    if loras is None:
        loras = LORAS

    try:
        lora = loras[lora_index]
    except (LookupError, TypeError):
        # Out-of-range index/key, or no usable index (e.g. None).
        return ""

    if not lora:
        # Empty or None entry: nothing to read trigger words from.
        return ""

    return lora.get('trigger_word', "")
|
|
|
def upd_prompt_notes(model_textbox_value):
    """
    Refresh the prompt_notes_label using the notes stored in LORA_DETAILS.

    For a known model, the detail items are scanned in order and the value of
    the first item that has a ``'notes'`` key is shown. If none has one, an
    empty string is shown. For an unknown model, a default prompt hint is
    shown instead.

    Args:
        model_textbox_value (str): The name of the LoRA model.

    Returns:
        gr.update: Updated Gradio label component with the notes.
    """
    if model_textbox_value not in LORA_DETAILS:
        return gr.update(value="Enter Prompt description of your image, \nusing models without LoRa may take a 30 minutes.")

    details = LORA_DETAILS[model_textbox_value]
    first_notes = next(
        (detail['notes'] for detail in details if 'notes' in detail),
        "",
    )
    return gr.update(value=first_notes)
|
|
|
def get_trigger_words(model_textbox_value):
    """
    Look up the trigger words for a LoRA model in ``LORA_DETAILS``.

    The detail items stored under ``LORA_DETAILS[model_textbox_value]`` are
    scanned in order. The value of the first item that has a
    ``'trigger_words'`` key is returned.

    Args:
        model_textbox_value (str): The name of the LoRA model.

    Returns:
        str: The stored trigger words, or an empty string when the model is
        unknown or none of its detail items has a ``'trigger_words'`` key.
    """
    if model_textbox_value not in LORA_DETAILS:
        return ""

    return next(
        (
            detail['trigger_words']
            for detail in LORA_DETAILS[model_textbox_value]
            if 'trigger_words' in detail
        ),
        "",
    )
|
|
|
def upd_trigger_words(model_textbox_value):
    """
    Refresh the trigger_words_label for the given LoRA model.

    Delegates the lookup to ``get_trigger_words`` and wraps the result in a
    Gradio update.

    Args:
        model_textbox_value (str): The name of the LoRA model.

    Returns:
        gr.update: Updated Gradio label component with the trigger words.
    """
    return gr.update(value=get_trigger_words(model_textbox_value))
|
|
|
def approximate_token_count(prompt):
    """
    Estimate how many tokens a prompt contains.

    Uses a fixed heuristic of 1.3 tokens per whitespace-separated word. The
    product is truncated toward zero.

    Parameters:
        prompt (str): The text prompt.

    Returns:
        int: The approximate number of tokens.
    """
    word_count = len(prompt.split())
    return int(1.3 * word_count)
|
|
|
def split_prompt_by_tokens(prompt, token_number, tokens_per_word=1.3):
    """
    Split a prompt into a head of roughly ``token_number`` tokens and the rest.

    The estimate uses the same heuristic as ``approximate_token_count``:
    ``tokens_per_word`` tokens per whitespace-separated word. Converting a
    token budget into a word count therefore divides by that ratio. The head
    keeps ``int(token_number / tokens_per_word)`` words, and negative budgets
    are treated as 0. Both parts are re-joined with single spaces, so runs of
    whitespace (including newlines) in the prompt are collapsed.

    Args:
        prompt (str): The text prompt.
        token_number (int | float): Approximate token budget for the head.
        tokens_per_word (float): Estimated tokens per word; must be positive.

    Returns:
        tuple[str, str]: ``(head, remainder)``; ``remainder`` is ``""`` when
        the whole prompt fits in the budget.

    Raises:
        ValueError: If ``tokens_per_word`` is not positive.
    """
    if tokens_per_word <= 0:
        raise ValueError("tokens_per_word must be positive")

    words = prompt.split()
    # tokens = words * ratio  =>  words = tokens / ratio
    word_limit = max(0, int(token_number / tokens_per_word))
    return ' '.join(words[:word_limit]), ' '.join(words[word_limit:])
|
|
|
|
|
import tiktoken |
|
|
|
def split_prompt_precisely(prompt, max_tokens=77, model="gpt-3.5-turbo"):
    """
    Split a prompt into a head of at most ``max_tokens`` tokens and the rest.

    Tokens are counted with tiktoken's encoding for ``model``. If tiktoken
    does not recognise the model name, it falls back to ``cl100k_base``.
    NOTE(review): if the 77-token default is meant to match an image model's
    text-encoder limit, tiktoken's BPE only approximates that tokenizer;
    confirm.

    Special-token strings such as ``<|endoftext|>`` in the prompt are encoded
    as ordinary text instead of raising.

    The cut is moved back to a UTF-8 character boundary. A multi-byte
    character (emoji, CJK, accented letters) is therefore never split, which
    would otherwise turn into U+FFFD replacement characters. As a result,
    ``head + remainder`` reproduces the prompt.

    Args:
        prompt (str): The text prompt.
        max_tokens (int): Maximum tokens allowed in the head; negative values
            are treated as 0.
        model (str): Model name passed to ``tiktoken.encoding_for_model``.

    Returns:
        tuple[str, str]: ``(head, remainder)``; ``remainder`` is ``""`` when
        the whole prompt fits.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown model name: use a general-purpose encoding instead.
        encoding = tiktoken.get_encoding("cl100k_base")

    # disallowed_special=() treats special-token text in user input as plain text.
    tokens = encoding.encode(prompt, disallowed_special=())

    if len(tokens) <= max_tokens:
        return prompt, ""

    split_point = max(0, max_tokens)

    # Work in bytes: decoding a token slice directly is lossy when the cut
    # lands inside a multi-byte UTF-8 sequence.
    full_bytes = encoding.decode_bytes(tokens)
    cut = len(encoding.decode_bytes(tokens[:split_point]))
    # Step back over UTF-8 continuation bytes (0b10xxxxxx) to a char boundary.
    while 0 < cut < len(full_bytes) and (full_bytes[cut] & 0xC0) == 0x80:
        cut -= 1

    split_prompt = full_bytes[:cut].decode("utf-8")
    remaining_prompt = full_bytes[cut:].decode("utf-8")

    return split_prompt, remaining_prompt
|
|