# NOTE(review): removed viewer/blame metadata (file size, commit hashes,
# line-number gutter) that was not part of the original source file.
# utils/lora_details.py
import gradio as gr
from utils.constants import LORA_DETAILS
def upd_prompt_notes_by_index(lora_index):
    """
    Updates the prompt_notes_label with the notes from LORAS based on index.

    Args:
        lora_index (int): The index of the selected LoRA model.

    Returns:
        gr.update: Updated Gradio label component with the notes.
    """
    # Fallback text, used when the index is out of range.  Assigning it up
    # front also fixes a former UnboundLocalError: the old code left `notes`
    # unassigned when the looked-up entry was falsy.
    notes = "Enter Prompt description of your image, \nusing models without LoRa may take a 30 minutes."
    try:
        entry = LORAS[lora_index]
    except IndexError:
        entry = None
    if entry:
        notes = entry.get('notes', None)
        if notes is None:
            # No explicit notes: synthesize a hint from the trigger metadata.
            trigger_word = entry.get('trigger_word', "")
            trigger_position = entry.get('trigger_position', "")
            notes = f"{trigger_position} '{trigger_word}' in prompt"
    return gr.update(value=notes)
def get_trigger_words_by_index(lora_index):
    """
    Looks up the trigger words for the LoRA model at *lora_index* in LORAS.

    Args:
        lora_index (int): The index of the selected LoRA model.

    Returns:
        str: The model's trigger words, or "" when the index is out of range
        or the entry has no 'trigger_word' key.
    """
    try:
        entry = LORAS[lora_index]
    except IndexError:
        return ""
    return entry.get('trigger_word', "")
def upd_prompt_notes(model_textbox_value):
    """
    Builds a Gradio update for the prompt notes label from LORA_DETAILS.

    Args:
        model_textbox_value (str): The name of the LoRA model.

    Returns:
        gr.update: Updated Gradio label component with the notes.
    """
    if model_textbox_value in LORA_DETAILS:
        # First detail dict that carries a 'notes' key wins; "" when none do.
        notes = next(
            (item['notes']
             for item in LORA_DETAILS[model_textbox_value]
             if 'notes' in item),
            "",
        )
    else:
        notes = "Enter Prompt description of your image, \nusing models without LoRa may take a 30 minutes."
    return gr.update(value=notes)
def get_trigger_words(model_textbox_value):
    """
    Retrieves the trigger words from constants.LORA_DETAILS for the specified model.

    Args:
        model_textbox_value (str): The name of the LoRA model.

    Returns:
        str: Trigger words of the first matching detail entry, or "" when the
        model is unknown or no entry defines 'trigger_words'.
    """
    if model_textbox_value not in LORA_DETAILS:
        return ""
    detail_entries = LORA_DETAILS[model_textbox_value]
    return next(
        (entry['trigger_words'] for entry in detail_entries if 'trigger_words' in entry),
        "",
    )
def upd_trigger_words(model_textbox_value):
    """
    Builds a Gradio update for the trigger_words_label.

    Args:
        model_textbox_value (str): The name of the LoRA model.

    Returns:
        gr.update: Updated Gradio label component with the trigger words.
    """
    return gr.update(value=get_trigger_words(model_textbox_value))
def approximate_token_count(prompt):
    """
    Approximates the number of tokens in a prompt based on word count.

    Parameters:
        prompt (str): The text prompt.

    Returns:
        int: The approximate number of tokens.
    """
    # Rough heuristic: English text averages about 1.3 tokens per word
    # (varies with language and tokenizer).
    avg_tokens_per_word = 1.3
    word_count = len(prompt.split())
    return int(word_count * avg_tokens_per_word)
def split_prompt_by_tokens(prompt, token_number):
    """
    Splits a prompt into two parts at an approximate token boundary.

    Parameters:
        prompt (str): The text prompt.
        token_number (int): Approximate token budget for the first part.

    Returns:
        tuple[str, str]: (head, tail) where head holds roughly the first
        `token_number` tokens' worth of words and tail holds the rest.
    """
    words = prompt.split()
    # At ~1.3 tokens per word (same heuristic as approximate_token_count),
    # a budget of N tokens covers about N / 1.3 words.  The previous version
    # multiplied instead of dividing, overshooting the token budget by ~69%.
    tokens_per_word = 1.3
    word_limit = int(token_number / tokens_per_word)
    return ' '.join(words[:word_limit]), ' '.join(words[word_limit:])
# Split prompt precisely by token count
import tiktoken

def split_prompt_precisely(prompt, max_tokens=77, model="gpt-3.5-turbo"):
    """
    Splits *prompt* at an exact token boundary using tiktoken.

    Parameters:
        prompt (str): The text prompt.
        max_tokens (int): Maximum tokens allowed in the first part.
        model (str): Model name used to pick the tokenizer encoding.

    Returns:
        tuple[str, str]: (head, tail); tail is "" when the whole prompt
        fits within `max_tokens`.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown model name: fall back to the general-purpose encoding.
        encoding = tiktoken.get_encoding("cl100k_base")
    tokens = encoding.encode(prompt)
    if len(tokens) <= max_tokens:
        return prompt, ""
    head, tail = tokens[:max_tokens], tokens[max_tokens:]
    return encoding.decode(head), encoding.decode(tail)