Spaces:
Runtime error
Runtime error
File size: 3,486 Bytes
7123498 c01121d 7123498 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import json
import math
import random
import os
import streamlit as st
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
@st.cache_resource
def load_model():
    """Load and cache the tokenizer and model.

    Streamlit re-executes the whole script on every widget interaction;
    without caching, the multi-GB checkpoint would be re-downloaded /
    re-loaded on each rerun. ``st.cache_resource`` keeps a single shared
    instance for the process lifetime.

    Returns:
        tuple: (tokenizer, model) for "BenBranyon/hotrss-mistral-full".
    """
    tok = AutoTokenizer.from_pretrained("BenBranyon/hotrss-mistral-full")
    mdl = AutoModelForCausalLM.from_pretrained("BenBranyon/hotrss-mistral-full")
    return tok, mdl


tokenizer, model = load_model()
# --- Page chrome: title, background image, and CSS overrides. ---
# NOTE: the .st-emotion-cache-* selectors target Streamlit's generated
# class names, which are version-specific and may break on upgrade.
st.set_page_config(page_title="House of the Red Solar Sky")
st.markdown(
    """
<style>
#house-of-the-red-solar-sky {
    text-align: center;
}
.stApp {
    background-image: url('https://f4.bcbits.com/img/a1824579252_16.jpg');
    background-repeat: no-repeat;
    background-size: cover;
    background-blend-mode: hard-light;
}
.st-emotion-cache-1avcm0n {
    background: none;
}
.st-emotion-cache-183lzff {
    overflow-x: unset;
    text-wrap: pretty;
}
</style>
""",
    unsafe_allow_html=True,
)
st.title("House of the Red Solar Sky")
# Fix: the original injected an unterminated <style> tag here, which can
# swallow subsequent page content; the closing </style> is now present.
st.markdown(
    """
<style>
.aligncenter {
    text-align: center;
}
</style>
""",
    unsafe_allow_html=True,
)
def post_process(output_sequences):
    """Decode generated token sequences into cleaned lyric strings.

    Each sequence is decoded with the module-level ``tokenizer`` (special
    tokens skipped), whitespace-stripped, and runs of two or three
    consecutive newlines are collapsed to a single newline.

    Args:
        output_sequences: iterable of token-id tensors as returned by
            ``model.generate`` (each must support ``.tolist()``).

    Returns:
        list[str]: one cleaned lyric string per input sequence.
    """
    decoded = [
        tokenizer.decode(
            sequence.tolist(),
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        ).strip()
        for sequence in output_sequences
    ]
    # Collapse blank lines: triple newlines first, then doubles.
    return [
        text.replace('\n\n\n', '\n').replace('\n\n', '\n')
        for text in decoded
    ]
# --- UI: prompt input and generation trigger. ---
# NOTE(review): "Sasquath" looks like a typo for "Sasquatch" — left as-is
# since it is runtime prompt text; confirm whether it is intentional.
start = st.text_input("Beginning of the song:", "Rap like a Sasquath in the trees")
if st.button("Run"):
    if model is not None:
        with st.spinner(text="Generating lyrics..."):
            # Encode the user prompt without special tokens so the model
            # treats it as a continuation, and move it to the model device.
            encoded_prompt = tokenizer(
                start, add_special_tokens=False, return_tensors="pt"
            ).input_ids
            encoded_prompt = encoded_prompt.to(model.device)
            # Sampled generation; max_length bounds prompt + continuation.
            output_sequences = model.generate(
                input_ids=encoded_prompt,
                max_length=160,
                min_length=100,
                temperature=1.0,
                top_p=0.95,
                top_k=50,
                do_sample=True,
                repetition_penalty=1.0,
                num_return_sequences=1,
            )
            # Decode, clean up, and display the lyrics.
            predictions = post_process(output_sequences)
            st.subheader("Results")
            for prediction in predictions:
                st.text(prediction)