metadata
library_name: transformers
base_model: google/gemma-2b-it
datasets:
  - Sengil/stable-diffusion-prompts-keywords
language:
  - en
tags:
  - stable-diffusion
  - prompt-generation
  - transformers
  - text-to-prompt
  - Text-to-Text
model_name: gemma-StableDiffusion-prompt-generator-v1
license: apache-2.0

Model Card for gemma-StableDiffusion-prompt-generator-v1

The gemma-StableDiffusion-prompt-generator-v1 model generates prompts for Stable Diffusion, a family of text-to-image models that create images from textual descriptions. Given a word or a sentence, the model expands it into a detailed, well-crafted prompt. That prompt can then guide Stable Diffusion toward visually appealing and contextually accurate images, which saves users the effort of writing such prompts by hand.

Details

Model Description

  • Developed by: Mert Sengil
  • Model type: Text-to-Text
  • Language(s) (NLP): English
  • License: MIT License
  • Finetuned from model: google/gemma-2b-it

Uses

Below are some code snippets to help you get started quickly. First, install or upgrade transformers (the snippets also use PyTorch):

!pip install -U transformers

Then copy the snippet from the section relevant to your use case.

Step-by-step usage:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "Sengil/gemma-StableDiffusion-prompt-generator-v1"

# Load the fine-tuned model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Convert to float32 once, move to GPU if available, and switch to inference mode
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = model.float().to(device)
model.eval()

# Instruction prepended to every query
PROMPT_PREFIX = "Generate a creative and descriptive stable diffusion prompt for the following query: "

def get_completion(query: str, model, tokenizer) -> str:
    prompt = PROMPT_PREFIX + query

    # Tokenize and place the tensors on the same device as the model
    model_inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(model.device)

    # Sampling parameters
    generation_params = {
        "max_new_tokens": 250,       # generate up to 250 tokens beyond the input
        "do_sample": True,
        "temperature": 0.9,
        "top_k": 50,
        "top_p": 0.95,
        "pad_token_id": tokenizer.eos_token_id,
        "num_return_sequences": 1,   # increase to get multiple outputs
    }

    with torch.no_grad():  # disable gradient tracking for inference
        generated_ids = model.generate(**model_inputs, **generation_params)

    # Decode only the newly generated tokens so the echoed instruction and query are dropped
    input_length = model_inputs["input_ids"].shape[1]
    return tokenizer.decode(generated_ids[0][input_length:], skip_special_tokens=True).strip()

# Enter your query
query = "Colorful Cat"
result = get_completion(query, model, tokenizer)
print(f"Prompt:\n{result}")
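The temperature, top_k and top_p values in generation_params control how model.generate picks each next token when do_sample is enabled. The pure-Python sketch below illustrates the general technique on a single vector of logits. It is a simplified illustration only. transformers applies these steps as logits processors over batched tensors, and details such as tie handling or a minimum number of kept tokens may differ. The function names filter_candidates and sample_next_token are hypothetical.

```python
import math
import random


def filter_candidates(logits, temperature=0.9, top_k=50, top_p=0.95):
    """Return (token_ids, probabilities) surviving temperature, top-k and top-p filtering."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    # Temperature: dividing logits by a value < 1 sharpens the distribution.
    scaled = [x / temperature for x in logits]

    # Top-k: keep only the k highest-scoring token ids, highest first.
    ranked = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:top_k]

    # Softmax over the surviving ids; subtracting the max avoids overflow in exp().
    best = scaled[ranked[0]]
    weights = [math.exp(scaled[i] - best) for i in ranked]
    total = sum(weights)
    probs = [w / total for w in weights]

    # Top-p (nucleus): keep the smallest prefix whose cumulative probability reaches top_p.
    kept_ids, kept_probs, cumulative = [], [], 0.0
    for token_id, p in zip(ranked, probs):
        kept_ids.append(token_id)
        kept_probs.append(p)
        cumulative += p
        if cumulative >= top_p:
            break
    return kept_ids, kept_probs


def sample_next_token(logits, temperature=0.9, top_k=50, top_p=0.95, rng=random):
    """Draw one token id from the filtered candidates (weights need not sum to 1)."""
    ids, probs = filter_candidates(logits, temperature, top_k, top_p)
    return rng.choices(ids, weights=probs, k=1)[0]


# Example: sample from a toy 4-token vocabulary with a seeded generator
print(sample_next_token([2.0, 1.5, 0.3, -1.0], rng=random.Random(0)))
```

Lowering temperature or top_p narrows the candidate set and makes the generated prompts more predictable. Raising them increases variety.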

Direct Use

from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the fine-tuned model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Sengil/gemma-StableDiffusion-prompt-generator-v1")
model = AutoModelForCausalLM.from_pretrained("Sengil/gemma-StableDiffusion-prompt-generator-v1")

Inputs and Outputs

Input:

A single keyword, or a short phrase or sentence describing the desired image.

Output:

An English prompt intended for use with Stable Diffusion, expanded from the input and designed to be coherent and relevant to it.
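The input/output contract above can be sketched as two small string helpers. This sketch assumes the instruction template used in the usage snippet. The helper names build_instruction and extract_generated_prompt are hypothetical. Whitespace normalization and rejecting empty queries are assumptions, not documented model behavior. The extraction step removes the whole echoed instruction, query included, so only newly generated text remains.

```python
# Assumed to match the template in the usage snippet
PROMPT_PREFIX = "Generate a creative and descriptive stable diffusion prompt for the following query: "


def build_instruction(query: str) -> str:
    """Wrap a keyword or sentence in the instruction template."""
    query = " ".join(query.split())  # assumption: collapse stray spaces and newlines
    if not query:
        raise ValueError("query must contain at least one word")  # assumption: reject empty input
    return PROMPT_PREFIX + query


def extract_generated_prompt(decoded: str, instruction: str) -> str:
    """Drop the echoed instruction (template + query) from a decoded causal-LM output."""
    if decoded.startswith(instruction):
        decoded = decoded[len(instruction):]
    return decoded.strip()


instruction = build_instruction("  Colorful   Cat\n")
print(instruction)
```

These helpers can stand in for the prompt-formatting and output clean-up steps of the usage snippet when decoding the full generated sequence.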

Framework versions

  • PEFT 0.11.2.dev0