# NOTE(review): the original paste carried Hugging Face Spaces page residue here
# ("Spaces:" / "Runtime error" x2) — build-log chrome, not application source.
import json
import math
import random
import os

import streamlit as st
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM


@st.cache_resource
def _load_model_and_tokenizer():
    """Load and cache the tokenizer and model once per server process.

    Streamlit re-executes the whole script on every user interaction;
    without caching, the multi-gigabyte checkpoint would be reloaded on
    each rerun.
    """
    tok = AutoTokenizer.from_pretrained("BenBranyon/hotrss-mistral-full")
    mdl = AutoModelForCausalLM.from_pretrained("BenBranyon/hotrss-mistral-full")
    return tok, mdl


# Module-level names preserved for the rest of the script.
tokenizer, model = _load_model_and_tokenizer()
# Page chrome: browser-tab title plus CSS that centres the heading and paints
# the album art as a full-page background.
st.set_page_config(page_title="House of the Red Solar Sky")
st.markdown(
    """
    <style>
    #house-of-the-red-solar-sky {
        text-align: center;
    }
    .stApp {
        background-image: url('https://f4.bcbits.com/img/a1824579252_16.jpg');
        background-repeat: no-repeat;
        background-size: cover;
        background-blend-mode: hard-light;
    }
    /* NOTE: .st-emotion-cache-* selectors target auto-generated Streamlit
       class names and may break across Streamlit releases — confirm. */
    .st-emotion-cache-1avcm0n {
        background: none;
    }
    .st-emotion-cache-183lzff {
        overflow-x: unset;
        text-wrap: pretty;
    }
    </style>
    """,
    unsafe_allow_html=True,
)
st.title("House of the Red Solar Sky")
# Utility CSS class for centred text. The original injection never closed the
# <style> tag, leaving malformed HTML in the page — fixed here.
st.markdown(
    """
    <style>
    .aligncenter {
        text-align: center;
    }
    </style>
    """,
    unsafe_allow_html=True,
)
def post_process(output_sequences):
    """Decode generated token sequences into display-ready lyric strings.

    Parameters
    ----------
    output_sequences : iterable
        Token-id tensors as returned by ``model.generate`` (each supports
        ``.tolist()``).

    Returns
    -------
    list[str]
        One decoded string per sequence, stripped of special tokens and
        surrounding whitespace, with runs of two or three consecutive
        newlines collapsed to a single newline.
    """
    predictions = []
    for sequence in output_sequences:
        text = tokenizer.decode(
            sequence.tolist(),
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        )
        # Collapse blank lines so verses read as a continuous lyric sheet.
        # (The original also split on '\n' and re-joined with '\n' — an
        # identity round-trip — and carried a disabled de-duplication pass;
        # both removed as dead code.)
        collapsed = text.strip().replace('\n\n\n', '\n').replace('\n\n', '\n')
        predictions.append(collapsed)
    return predictions
# --- UI: prompt input, generation trigger, and results display --------------
# NOTE(review): "Sasquath" looks like a typo for "Sasquatch" — kept verbatim
# in case the spelling is an intentional stylisation; confirm with the author.
start = st.text_input("Beginning of the song:", "Rap like a Sasquath in the trees")
if st.button("Run"):
    if model is not None:
        with st.spinner(text="Generating lyrics..."):
            # Tokenize the prompt and move it to the model's device.
            encoded_prompt = tokenizer(
                start, add_special_tokens=False, return_tensors="pt"
            ).input_ids
            encoded_prompt = encoded_prompt.to(model.device)
            # Sampled generation; temperature/top-p/top-k chosen for variety.
            output_sequences = model.generate(
                input_ids=encoded_prompt,
                max_length=160,
                min_length=100,
                temperature=1.0,
                top_p=0.95,
                top_k=50,
                do_sample=True,
                repetition_penalty=1.0,
                num_return_sequences=1,
            )
            # Decode and tidy the raw token output for display.
            predictions = post_process(output_sequences)
            st.subheader("Results")
            for prediction in predictions:
                st.text(prediction)