File size: 3,486 Bytes
7123498
 
 
 
 
 
 
 
c01121d
 
7123498
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import json
import math
import random
import os
import streamlit as st
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM

@st.cache_resource
def _load_model():
    """Load and cache the tokenizer and model for the lyrics generator.

    Streamlit reruns the whole script on every widget interaction;
    ``st.cache_resource`` keeps these heavyweight objects alive across
    reruns so the checkpoint is deserialized only once per server process.

    Returns:
        tuple: ``(tokenizer, model)`` for the hotrss-mistral-full checkpoint.
    """
    tok = AutoTokenizer.from_pretrained("BenBranyon/hotrss-mistral-full")
    mdl = AutoModelForCausalLM.from_pretrained("BenBranyon/hotrss-mistral-full")
    return tok, mdl


# Module-level names preserved so the rest of the script is unchanged.
tokenizer, model = _load_model()

# Raw CSS injected into the page: centers the title and paints the album
# art as a full-bleed background. The emotion-cache selectors target
# Streamlit's generated class names (version-specific).
_PAGE_CSS = """
        <style>
        #house-of-the-red-solar-sky {
          text-align: center;
        }
        .stApp {
          background-image: url('https://f4.bcbits.com/img/a1824579252_16.jpg');
          background-repeat: no-repeat;
          background-size: cover;
          background-blend-mode: hard-light;
        }
        .st-emotion-cache-1avcm0n {
          background: none;
        }
        .st-emotion-cache-1wmy9hl {
        }
        .st-emotion-cache-183lzff {
          overflow-x: unset;
          text-wrap: pretty;
        }
        </style>
            """

# set_page_config must be the first Streamlit command executed on the page.
st.set_page_config(page_title="House of the Red Solar Sky")
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
st.title("House of the Red Solar Sky")

# Helper CSS class for center-aligned markup.
# FIX: the original markup never closed the <style> tag — broken HTML that
# can leak the CSS rules into the rendered page as visible text.
st.markdown(
    """
        <style>
        .aligncenter {
          text-align: center;
        }
        </style>
    """,
    unsafe_allow_html=True,
)

def post_process(output_sequences):
    """Decode generated token sequences into cleaned-up lyric strings.

    Args:
        output_sequences: iterable of token-id sequences (each supporting
            ``.tolist()``), as returned by ``model.generate``.

    Returns:
        list[str]: one decoded string per input sequence, with special
        tokens removed, surrounding whitespace stripped, and runs of
        blank lines collapsed to single newlines.
    """
    predictions = []
    for sequence in output_sequences:
        text = tokenizer.decode(
            sequence.tolist(),
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        )
        # Collapse triple then double newlines so verses read compactly.
        # (Order matters: '\n\n\n\n' -> '\n\n' -> '\n', same as original.)
        cleaned = text.strip().replace('\n\n\n', '\n').replace('\n\n', '\n')
        predictions.append(cleaned)
    return predictions

start = st.text_input("Beginning of the song:", "Rap like a Sasquath in the trees")

if st.button("Run"):
    if model is not None:
        # FIX: dropped the pointless f-prefix on a placeholder-free string
        # and the redundant float()/int() casts of numeric literals.
        with st.spinner(text="Generating lyrics..."):
            # Tokenize the seed text without BOS/EOS so generation
            # continues the user's prompt seamlessly.
            encoded_prompt = tokenizer(
                start, add_special_tokens=False, return_tensors="pt"
            ).input_ids
            encoded_prompt = encoded_prompt.to(model.device)
            # Sampled decoding; temperature/top-p/top-k values unchanged.
            # NOTE(review): repetition_penalty=1.0 is a no-op default —
            # kept to preserve behavior, but it could be removed or tuned.
            output_sequences = model.generate(
                input_ids=encoded_prompt,
                max_length=160,
                min_length=100,
                temperature=1.0,
                top_p=0.95,
                top_k=50,
                do_sample=True,
                repetition_penalty=1.0,
                num_return_sequences=1,
            )
            # Decode and display the cleaned-up lyrics.
            predictions = post_process(output_sequences)
            st.subheader("Results")
            for prediction in predictions:
                st.text(prediction)