# Image-to-music / app.py
# Last commit 8b4a267 by Bokanovskii: "Update app.py"
import base64
import html
import os
import secrets
import time
from io import BytesIO

import gradio as gr
import gradio as gr
import numpy as np
import spotipy
import tensorflow as tf
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastapi.responses import RedirectResponse
from PIL import Image
from spotipy import oauth2
from starlette.concurrency import run_in_threadpool
from starlette.middleware.sessions import SessionMiddleware
from starlette.requests import Request
from starlette.responses import HTMLResponse, RedirectResponse
from torch.nn import functional as F
from torchvision.io import read_image
from transformers import ViTForImageClassification, ViTImageProcessor

import shred_model
# Keras classifier (Xception fine-tuned from pretrained ImageNet weights)
# used to decide whether Sraddha appears in the uploaded image.
SRADDHA_MODEL_PATH = "shred_model"
SHRED_MODEL = tf.keras.models.load_model(SRADDHA_MODEL_PATH)

# Not read anywhere in this file: the Spotify token actually lives in each
# user's session (request.session['token']), which `homepage` writes.
SPOTIPY_TOKEN = None

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One Hugging Face checkpoint provides both the ViT mood classifier and its
# matching image preprocessor (per its name, a cartoon emotion detector).
_MOOD_CHECKPOINT = "jayanta/google-vit-base-patch16-224-cartoon-emotion-detection"

print("Grabbing model")
mood_model = ViTForImageClassification.from_pretrained(_MOOD_CHECKPOINT).eval().to(device)

print("Grabbing feature extractor")
mood_feature_extractor = ViTImageProcessor.from_pretrained(_MOOD_CHECKPOINT)
def main(img, playlist_length, privacy, gen_mode, genre_choice, request: gr.Request):
    """Gradio click handler: turn an uploaded image into a Spotify playlist.

    Args:
        img: File path of the uploaded image (the input component uses
            ``type="filepath"``), or None when nothing was uploaded.
        playlist_length: Requested number of tracks (slider value).
        privacy: "Public" or "Private".
        gen_mode: "Favorites", "Recently Played" or "By a Chosen Genre".
        genre_choice: Genre name, used when gen_mode is "By a Chosen Genre".
        request: Gradio request; its session carries the Spotify token.

    Returns:
        A pair of ``gr.update`` objects for (playlist_output, sraddhas_box).
        The playlist box gets the Spotify embed HTML or an error message; the
        second box is shown with a personal note only when the Sraddha
        classifier fires.
    """
    if img is None:
        # Two output components are wired to this handler, so two values must
        # come back; leave the playlist box untouched and keep the note hidden.
        return gr.update(), gr.update(visible=False)

    print("Getting image inference from transformer")
    mood_dict = get_image_mood_dict_from_transformer(img)
    print("Getting Sraddha Found Boolean from model")
    sraddha_found = get_sraddha(img)
    print("Building playlist")
    playlist = get_playlist(mood_dict, img, playlist_length, privacy, gen_mode, genre_choice, request)
    if playlist is None:
        playlist = "Spotipy account token not set"

    if not sraddha_found:
        return gr.update(value=playlist, visible=True), gr.update(visible=False)

    valentines_jokes = ["Why shouldn't you trust a pastry chef on Valentine's Day? Because he will dessert you.",
                        "What do you give your Valentine in France? A big quiche.",
                        "What did the tortoise say on Valentine's Day? I turt-ally love you.",
                        "How did the squirrel get his Valentine's attention? He acted like a nut.",
                        "What do you call sweets that can keep a beat? Candy rappers.",
                        "What did the paper clip say to the magnet? I find you very attractive.",
                        "What did the calculator say to the pencil? You can count on me."]
    # randint's upper bound is exclusive: passing the full length makes every
    # joke, including the last one, selectable.
    joke = valentines_jokes[np.random.randint(len(valentines_jokes))]
    sraddha_msg = ("Sraddha, you are the love of my life and seeing you always lifts my spirits. "
                   "Hopefully these tunes and a joke can do the same for you.\n<p>\n</p>"
                   f"<p>{joke}</p><p></p>"
                   "- With Love, Scoob")
    return gr.update(value=playlist, visible=True), gr.update(value=sraddha_msg, visible=True)
def get_image_mood_dict_from_transformer(img):
    """Classify the emotional tone of an image with the ViT mood model.

    Args:
        img: Path to an image file.

    Returns:
        dict mapping each label in ``mood_model.config.id2label`` to its
        softmax probability (Python float).
    """
    from torchvision.io import ImageReadMode

    # Force 3-channel RGB: PNGs with alpha or grayscale images would otherwise
    # yield 4- or 1-channel tensors that don't match the processor's
    # 3-channel normalisation.
    img = read_image(img, mode=ImageReadMode.RGB)
    encoding = mood_feature_extractor(images=img, return_tensors="pt")
    pixel_values = encoding['pixel_values'].to(device)
    print('Running mood prediction')
    with torch.no_grad():
        logits = mood_model(pixel_values).logits
    # Move to CPU before converting: .numpy() fails on CUDA tensors.
    probabilities = F.softmax(logits, dim=-1)[0].cpu().numpy()
    id2label = mood_model.config.id2label
    # Index labels by class id instead of relying on dict insertion order.
    return {id2label[i]: float(p) for i, p in enumerate(probabilities)}
def get_sraddha(img, threshold=0.5):
    """Return True when the Sraddha classifier scores the image >= ``threshold``.

    Args:
        img: Image (path) handed to ``shred_model.prepare_image``.
        threshold: Decision threshold on the model's score.

    Returns:
        bool: always True or False, never None.
    """
    fixed_img = shred_model.prepare_image(img)
    # Take the first score of the prediction batch; assumes a single-output
    # (sigmoid-style) head, as the original 0.5 comparison implies — TODO confirm.
    prob = float(np.ravel(SHRED_MODEL.predict(fixed_img))[0])
    return prob >= threshold
def compute_mood(mood_dict):
    """Collapse per-emotion probabilities into a single mood score.

    'happy' counts fully, 'angry' at half weight and 'sad' at one tenth; any
    other labels are ignored. The input dict is printed first for debugging.

    Raises:
        KeyError: if 'happy', 'angry' or 'sad' is missing.
    """
    print(mood_dict)
    happy = mood_dict['happy']
    angry = mood_dict['angry']
    sad = mood_dict['sad']
    return happy + angry * .5 + sad * .1
# Spotify seed genres used for each genre choice offered in the UI.
_GENRE_SEEDS = {'Rock': ['alt-rock', 'alternative', 'indie', 'r-n-b', 'rock'], 'Hip-hop': ['hip-hop'],
                'Party': ['house', 'pop', 'party'], 'Mellow': ['blues', 'jazz', 'happy'],
                'Indian': ['idm', 'indian'], 'Pop': ['pop', 'new-age'],
                'Study': ['study', 'classical', 'jazz', 'happy', 'chill'],
                'Romance': ['romance', 'happy', 'pop']}


def _clamp01(x):
    """Clamp ``x`` into [0, 1], the range of Spotify's tunable audio attributes."""
    return min(max(x, 0.0), 1.0)


def _saved_and_recent_track_pool(sp):
    """Collect candidate track URIs for the "Recently Played" mode.

    Pools the last 50 played tracks, the first 200 saved tracks (four pages of
    50) and up to 100 recommendations seeded from the first ten saved tracks,
    with duplicates removed.
    """
    pool = {x['track']['uri'] for x in sp.current_user_recently_played(limit=50)['items']}
    saved = []
    for offset in range(0, 200, 50):
        saved += [x['track']['uri'] for x in sp.current_user_saved_tracks(limit=50, offset=offset)['items']]
    pool.update(saved)
    # Spotify requires at least one seed, so skip seed groups that are empty
    # (users with fewer than 6 saved tracks).
    for seeds in (saved[:5], saved[5:10]):
        if seeds:
            pool.update(x['uri'] for x in sp.recommendations(seed_tracks=seeds, limit=50)['tracks'])
    return list(pool)


def _fill_with_recommendations(sp, tracks, mood, playlist_length):
    """Top ``tracks`` up toward ``playlist_length`` with mood-tuned recommendations.

    Seeds come from the five most recently played tracks. Returns the
    (possibly extended) list; URIs already present are not added again.
    """
    missing = playlist_length - len(tracks)
    if missing <= 0:
        return tracks
    print(f'Filling playlist with {missing} more songs (filter too big)')
    seeds = list(dict.fromkeys(x['track']['uri'] for x in sp.current_user_recently_played(limit=5)['items']))
    if not seeds:
        # Nothing to seed a recommendations request with.
        return tracks
    extra = [x['uri'] for x in sp.recommendations(
        seed_tracks=seeds, limit=missing,
        min_valence=_clamp01(mood - .3), min_energy=_clamp01(mood / 3))['tracks']]
    filled = list(tracks)
    seen = set(filled)
    for uri in extra:
        if uri not in seen:
            seen.add(uri)
            filled.append(uri)
    return filled


def get_playlist(mood_dict, img, playlist_length, privacy, gen_mode, genre_choice, request: gr.Request):
    """Build a mood-matched Spotify playlist for the logged-in user.

    Args:
        mood_dict: Emotion label -> probability, reduced via ``compute_mood``.
        img: Image path, used as the playlist cover.
        playlist_length: Target number of tracks.
        privacy: "Public" or "Private".
        gen_mode: "Recently Played", "By a Chosen Genre", or anything else
            (treated as "Favorites": top tracks of favourite artists).
        genre_choice: Key of ``_GENRE_SEEDS``, used in genre mode.
        request: Gradio request whose session holds the Spotify token.

    Returns:
        The result of ``create_playlist`` (embed iframe HTML, or its
        no-tracks message), or None when the session has no token.
    """
    token = request.request.session.get('token')
    if not token:
        return None

    # Spotify's `limit` must be an integer; the slider may hand back a float.
    playlist_length = int(playlist_length)
    mood = compute_mood(mood_dict)
    label = genre_choice if gen_mode == "By a Chosen Genre" else gen_mode
    playlist_name = "Mood " + str(round(mood * 100, 1)) + f": {label}"
    sp = spotipy.Spotify(token)

    if gen_mode == 'By a Chosen Genre':
        # Tunable attributes are clamped so low/high moods never send values
        # outside [0, 1].
        final_track_list = [x['uri'] for x in sp.recommendations(
            seed_genres=_GENRE_SEEDS[genre_choice], limit=playlist_length,
            max_valence=_clamp01(mood + .15), min_valence=_clamp01(mood - .15),
            min_danceability=_clamp01(mood / 1.75), max_danceability=_clamp01(mood * 8),
            min_energy=_clamp01(mood / 2))['tracks']]
    else:
        if gen_mode == 'Recently Played':
            candidates = _saved_and_recent_track_pool(sp)
        else:
            candidates = aggregate_top_tracks(sp, aggregate_favorite_artists(sp))
        final_track_list = filter_tracks(sp, candidates, mood, playlist_length)

    # If too few tracks matched, generate some results anyway.
    final_track_list = _fill_with_recommendations(sp, final_track_list, mood, playlist_length)
    return create_playlist(sp, img, final_track_list, playlist_name, privacy)
def _upload_cover_image(sp, playlist_id, img, attempts=2, retry_delay=5):
    """Best-effort: set the image at path ``img`` as the playlist's cover.

    The image is shrunk to fit 300x300 and converted to RGB (JPEG cannot
    store alpha/palette modes). It is then uploaded base64-encoded. Any
    failure is logged and retried up to ``attempts`` times in total; it never
    aborts playlist creation.
    """
    for attempt in range(1, attempts + 1):
        try:
            with Image.open(img) as im_file:
                im_file.thumbnail((300, 300))
                buffered = BytesIO()
                im_file.convert("RGB").save(buffered, format="JPEG")
            sp.playlist_upload_cover_image(playlist_id, base64.b64encode(buffered.getvalue()))
            return
        except Exception as e:  # deliberate best-effort: log, maybe retry
            print(f"Cover image upload failed (attempt {attempt}/{attempts}): {e}")
            if attempt < attempts:
                time.sleep(retry_delay)


def create_playlist(sp, img, tracks, playlist_name, privacy):
    """Create a Spotify playlist of ``tracks`` and return an embed for it.

    Args:
        sp: Authenticated ``spotipy.Spotify`` client.
        img: Image path used for the cover (upload is best-effort).
        tracks: Track URIs to add, in order.
        playlist_name: Name of the new playlist.
        privacy: "Public" makes the playlist public; anything else private.

    Returns:
        HTML ``<iframe>`` embedding the new playlist. Returns the string
        "No tracks could be generated from this image" when ``tracks`` is
        empty; in that case no playlist is created.
    """
    if not tracks:
        # Check before creating anything so the user isn't left with an
        # empty playlist on their account.
        return """No tracks could be generated from this image"""

    is_public = privacy == "Public"
    user_id = sp.current_user()['id']
    playlist_description = "This playlist was created using the img-to-music application built by the best boyfriend there ever was and ever will be"
    playlist_data = sp.user_playlist_create(user_id, playlist_name, public=is_public,
                                            description=playlist_description)
    playlist_id = playlist_data['id']

    # Spotify accepts at most 100 items per add-to-playlist request.
    for start in range(0, len(tracks), 100):
        sp.user_playlist_add_tracks(user_id, playlist_id, tracks[start:start + 100])

    _upload_cover_image(sp, playlist_id, img)
    # Short pause before returning the embed; presumably gives Spotify time to
    # process the new playlist/cover — unverified.
    time.sleep(3)
    iframe_embedding = f"""<iframe style="border-radius:12px" src="https://open.spotify.com/embed/playlist/{playlist_id}" width="100%" height="352" frameBorder="0" allowfullscreen="" allow="autoplay; clipboard-write; encrypted-media; fullscreen; picture-in-picture" loading="lazy"></iframe>"""
    return iframe_embedding
def _add_new_artists(artists, seen_names, uris):
    """Append the URI of each artist whose name is not yet in ``seen_names``."""
    for artist in artists:
        if artist['name'] not in seen_names:
            seen_names.add(artist['name'])
            uris.append(artist['uri'])


def aggregate_favorite_artists(sp, target=200):
    """Collect artist URIs the user likes, deduplicated by artist name.

    Starts from the user's top artists over the short, medium and long time
    ranges (50 each) plus up to 50 followed artists. Then expands through
    related artists until ``target`` artists are found or there is nothing
    left to expand.

    Args:
        sp: Authenticated ``spotipy.Spotify`` client.
        target: Number of artists to aim for (default 200).

    Returns:
        list of artist URIs in discovery order (may be shorter than
        ``target``, and empty for users with no top or followed artists).
    """
    seen_names = set()
    uris = []
    for time_range in ('short_term', 'medium_term', 'long_term'):
        _add_new_artists(sp.current_user_top_artists(limit=50, time_range=time_range)['items'],
                         seen_names, uris)
    _add_new_artists(sp.current_user_followed_artists(limit=50)['artists']['items'],
                     seen_names, uris)

    # Breadth-first expansion through related artists. Bounding `i` by the
    # list length stops when every known artist has been expanded. It also
    # avoids indexing into an empty list.
    i = 0
    while len(uris) < target and i < len(uris):
        _add_new_artists(sp.artist_related_artists(uris[i])['artists'], seen_names, uris)
        i += 1
    return uris
def aggregate_top_tracks(sp, top_artists_uri):
    """Gather the Spotify top tracks of every artist in ``top_artists_uri``.

    Returns:
        list of track URIs in artist order, with duplicates dropped. A single
        track (e.g. a collaboration) can appear in several artists' top-track
        lists; deduplicating keeps it from landing in the playlist twice.
    """
    return list(dict.fromkeys(
        track['uri']
        for artist in top_artists_uri
        for track in sp.artist_top_tracks(artist)['tracks']
    ))
def filter_tracks(sp, top_tracks_uri, mood, playlist_length):
selected_tracks_uri = []
np.random.shuffle(top_tracks_uri)
# Batch network requests
BATCH_SIZE = 100
i = 0
all_track_data = []
while i + BATCH_SIZE < len(top_tracks_uri):
all_track_data += sp.audio_features(top_tracks_uri[i:i+BATCH_SIZE])
i += BATCH_SIZE
all_track_data += sp.audio_features(top_tracks_uri[i:])
for i, track in enumerate(top_tracks_uri):
track_data = all_track_data[i]
if track_data is None:
continue
valence = track_data['valence']
danceability = track_data['danceability']
energy = track_data['energy']
if mood < .1:
if valence <= mood + .15 and \
danceability <= mood * 8 and \
energy <= mood * 10:
selected_tracks_uri.append(track)
elif mood < .25:
if (mood - .1) <= valence <= (mood + .1) and \
danceability <= mood * 4 and \
energy <= mood * 5:
selected_tracks_uri.append(track)
elif mood < .5:
if mood - .05 <= valence <= mood + .05 and \
danceability <= mood * 1.75 and \
energy <= mood * 1.75:
selected_tracks_uri.append(track)
elif mood < .75:
if mood - .1 <= valence <= mood + .1 and \
danceability >= mood / 2.5 and \
energy >= mood / 2:
selected_tracks_uri.append(track)
elif mood < .9:
if mood - .1 <= valence <= mood + .1 and \
danceability >= mood / 2 and \
energy >= mood / 1.75:
selected_tracks_uri.append(track)
else:
if mood - .15 <= valence <= 1 and \
danceability >= mood / 1.75 and \
energy >= mood / 1.5:
selected_tracks_uri.append(track)
if len(selected_tracks_uri) >= playlist_length:
break
return selected_tracks_uri
# Define login and frontend
# NOTE: not used in this file; uvicorn is started on port 7860.
PORT_NUMBER = 8080
# Spotify app credentials, overridable via environment variables (e.g. Space
# secrets); the committed literals are only a fallback.
# SECURITY(review): a client secret committed to source control must be
# treated as leaked — rotate it in the Spotify dashboard and drop the fallback.
SPOTIPY_CLIENT_ID = os.environ.get('SPOTIPY_CLIENT_ID', '2320153024d042c8ba138a108066246c')
SPOTIPY_CLIENT_SECRET = os.environ.get('SPOTIPY_CLIENT_SECRET', 'da2746490f6542a3b0cfcff50893e8e8')
# For local development set SPOTIPY_REDIRECT_URI=http://localhost:7860
SPOTIPY_REDIRECT_URI = os.environ.get('SPOTIPY_REDIRECT_URI', "https://Bokanovskii-Image-to-music.hf.space")
SCOPE = 'ugc-image-upload playlist-read-private playlist-read-collaborative playlist-modify-private playlist-modify-public user-top-read user-read-playback-position user-read-recently-played user-read-email user-follow-read user-library-modify user-library-read user-read-email user-read-private user-read-playback-state user-modify-playback-state user-read-currently-playing app-remote-control streaming'
sp_oauth = oauth2.SpotifyOAuth(SPOTIPY_CLIENT_ID, SPOTIPY_CLIENT_SECRET, SPOTIPY_REDIRECT_URI, scope=SCOPE)
app = FastAPI()
# The session cookie carries each user's Spotify token and is signed with this
# key. A guessable hard-coded key lets anyone forge sessions, so take the key
# from the environment, or fall back to a random per-process key (sessions
# then reset on restart).
app.add_middleware(SessionMiddleware,
                   secret_key=os.environ.get('SESSION_SECRET_KEY') or secrets.token_urlsafe(32))
def _render_login_page(auth_url, message=None, font_size="1.75rem"):
    """Return the landing-page HTML with a "Login to Spotify" link.

    Args:
        auth_url: Spotify authorization URL. It is HTML-escaped before being
            placed in the ``href`` attribute, since it carries a query string.
        message: Optional notice (trusted HTML) shown above the login link.
        font_size: CSS font size of the title block.
    """
    lines = [
        '<div style="text-align: center; max-width: 1000px; margin: 0 auto;">',
        f'<div style="align-items: center; gap: 0.8rem; font-size: {font_size};">',
        '<h3 style="font-weight: 900; margin-bottom: 30px; margin-top: 20px;">',
        'Image to Music Generator',
        '</h3>',
    ]
    if message:
        lines.append(f'<p> {message} </p>')
    lines += [
        f"<a href='{html.escape(auth_url, quote=True)}'>Login to Spotify</a>",
        '<p></p>',
        '<p></p>',
        "<small>Click 'Open in a new window/tab'</small>",
        '<div style="align-items: center; gap: 0.8rem; font-size: 1rem;">',
        '<small>This applet requires a whitelisted Spotify account (contact Charlie Ward)</small>',
        '</div>',
        '</div>',
        '</div>',
    ]
    return "\n".join(lines)


@app.get('/', response_class=HTMLResponse)
async def homepage(request: Request):
    """Landing page and Spotify OAuth callback.

    When the request URL carries an authorization code, exchange it for an
    access token, store the token in the user's session and redirect to the
    Gradio UI at ``/gradio``. Otherwise render a page with a login link. If
    the exchange fails, that page also reports the failure.
    """
    url = str(request.url)
    auth_url = sp_oauth.get_authorize_url()
    try:
        code = sp_oauth.parse_response_code(url)
        # parse_response_code hands back the URL itself when there is no code.
        if code != url:
            # The token exchange is a blocking HTTP call; run it in a worker
            # thread so the event loop (and the mounted Gradio app) isn't stalled.
            request.session['token'] = await run_in_threadpool(
                sp_oauth.get_access_token, code, as_dict=False, check_cache=False)
            return RedirectResponse("/gradio")
    except Exception as e:
        print(f"Spotify login failed: {e}")
        return _render_login_page(
            auth_url,
            message="The server couldn't make a connection with Spotify: please try again",
            font_size="1.25rem")
    return _render_login_page(auth_url)
# Gradio front end; mounted under /gradio, where `homepage` redirects after a
# successful Spotify login.
with gr.Blocks(css="style.css") as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML("""<div style="text-align: center; max-width: 700px; margin: 0 auto;">
<div
style="
display: inline-flex;
align-items: center;
gap: 0.8rem;
font-size: 1.75rem;
"
>
<h1 style="font-weight: 900; margin-bottom: 7px; margin-top: 5px;">
Image to Music Generator
</h1>""")
        # Uploaded image, passed to `main` as a file path.
        input_img = gr.Image(type="filepath", elem_id="input-img")
        # Personal-message box: hidden by default; `main` reveals it when the
        # Sraddha classifier fires.
        sraddhas_box = gr.HTML(label="Sraddha's Box", elem_id="sraddhas-box", visible=False)
        # Receives the playlist embed iframe (or an error message) from `main`.
        playlist_output = gr.HTML(label="Generated Playlist", elem_id="app-output", visible=True)
        with gr.Accordion(label="Playlist Generation Options", open=False):
            playlist_length = gr.Slider(minimum=5, maximum=100, value=30, step=5,
                                        label="Playlist Length", elem_id="playlist-length")
            with gr.Row():
                privacy = gr.Radio(label="Playlist Privacy Level", choices=["Public", "Private"],
                                   value="Private")
                gen_mode = gr.Radio(label="Recommendation Base", choices=["Favorites", "Recently Played", "By a Chosen Genre"], value="Favorites")
            # Genre picker row; hidden until "By a Chosen Genre" is selected
            # (toggled by genre_dropdown_toggle).
            with gr.Row(visible=False) as genre_choice_row:
                genre_choice = gr.Dropdown(label='Choose a Genre', choices=['Rock', 'Pop', 'Hip-hop', 'Party', 'Mellow', 'Indian', 'Study', 'Romance'], value='Pop')

        def sraddha_box_hide():
            """Hide Sraddha's box (run on every Generate click)."""
            return {sraddhas_box: gr.update(visible=False)}

        def genre_dropdown_toggle(gen_mode):
            """Show the genre row only for the "By a Chosen Genre" mode."""
            if gen_mode == 'By a Chosen Genre':
                return {genre_choice_row: gr.update(visible=True)}
            else:
                return {genre_choice_row: gr.update(visible=False)}

        generate = gr.Button("Generate Playlist from Image")
        article = """
<div class="footer">
<p>
Built for Sraddha: playlist generation from image inference
</p>
<p>
Sending Love 🤗
</p>
</div>
"""
        gr.HTML(article)
    gen_mode.change(genre_dropdown_toggle, inputs=[gen_mode], outputs=[genre_choice_row])
    # Two listeners on the same click: one hides any message left over from a
    # previous run, the other builds the playlist (also exposed via the API
    # as "img-to-music").
    generate.click(sraddha_box_hide, outputs=[sraddhas_box])
    generate.click(main, inputs=[input_img, playlist_length, privacy, gen_mode, genre_choice],
                   outputs=[playlist_output, sraddhas_box], api_name="img-to-music")
# Serve the Gradio UI from the same FastAPI app, under /gradio.
gradio_app = gr.mount_gradio_app(app, demo, "/gradio")

if __name__ == "__main__":
    # Start the server only when run as a script (`python app.py`), so that
    # importing the module (e.g. `uvicorn app:app`, tests) does not block.
    uvicorn.run(app, host="0.0.0.0", port=7860)