# Spaces:
# Runtime error
# Runtime error
import base64
import os
import secrets
import time
from io import BytesIO

# Third-party imports are kept in their original relative order: the import
# order of torch/tensorflow can matter in some environments.
import gradio as gr
import spotipy
from spotipy import oauth2
from transformers import ViTForImageClassification, ViTImageProcessor
import torch
from torch.nn import functional as F
from torchvision.io import read_image
import tensorflow as tf
from fastapi import FastAPI
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import HTMLResponse, RedirectResponse
from starlette.requests import Request
import gradio as gr
import uvicorn
from fastapi.responses import HTMLResponse
from fastapi.responses import RedirectResponse
import numpy as np
from PIL import Image

import shred_model
# Xception fine tuned from pretrained imagenet weights for identifying Sraddha.
# Loaded once at import time; `get_sraddha` thresholds its prediction at 0.5.
SRADDHA_MODEL_PATH = "shred_model"
SHRED_MODEL = tf.keras.models.load_model(SRADDHA_MODEL_PATH)

# NOTE(review): nothing in this file reads or assigns SPOTIPY_TOKEN; the Spotify
# access token is stored per user in the request session by `homepage` and read
# back in `get_playlist`. Kept only so any external reference keeps working.
SPOTIPY_TOKEN = None

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hugging Face Hub id of the ViT emotion classifier. The model and its image
# processor must come from the same checkpoint, so both loads share this id.
MOOD_MODEL_ID = "jayanta/google-vit-base-patch16-224-cartoon-emotion-detection"

print("Grabbing model")
mood_model = ViTForImageClassification.from_pretrained(MOOD_MODEL_ID)
mood_model.eval()  # inference only (disables dropout)
mood_model.to(device)

print("Grabbing feature extractor")
mood_feature_extractor = ViTImageProcessor.from_pretrained(MOOD_MODEL_ID)
# Jokes shown alongside the note when Sraddha is detected in the image.
_VALENTINES_JOKES = (
    "Why shouldn't you trust a pastry chef on Valentine's Day? Because he will dessert you.",
    "What do you give your Valentine in France? A big quiche.",
    "What did the tortoise say on Valentine's Day? I turt-ally love you.",
    "How did the squirrel get his Valentine's attention? He acted like a nut.",
    "What do you call sweets that can keep a beat? Candy rappers.",
    "What did the paper clip say to the magnet? I find you very attractive.",
    "What did the calculator say to the pencil? You can count on me.",
)


def _sraddha_message():
    """Build the HTML note shown to Sraddha, including one random joke."""
    # randint's upper bound is exclusive, so passing len() makes every joke
    # (including the last one) reachable.
    joke = _VALENTINES_JOKES[np.random.randint(len(_VALENTINES_JOKES))]
    return """Sraddha, you are the love of my life and seeing you always lifts my spirits. Hopefully these tunes and a joke can do the same for you.
<p>
</p>""" + \
        f"<p>{joke}</p><p></p>" + \
        """- With Love, Scoob"""


def main(img, playlist_length, privacy, gen_mode, genre_choice, request: gr.Request):
    """Gradio handler: build a Spotify playlist from the mood of an uploaded image.

    Args:
        img: Filepath of the uploaded image (the input component is created
            with ``type="filepath"``), or None if nothing was uploaded.
        playlist_length: Requested number of tracks (slider value).
        privacy: "Public" or "Private".
        gen_mode: "Favorites", "Recently Played" or "By a Chosen Genre".
        genre_choice: Genre used when ``gen_mode`` is "By a Chosen Genre".
        request: Gradio request whose session holds the Spotify token.

    Returns:
        Two ``gr.update`` values, for the playlist output and for Sraddha's box.
    """
    # The click event feeds two output components, so every path must return
    # two values; a bare None cannot be unpacked into both outputs.
    if img is None:
        return gr.update(), gr.update(visible=False)
    # Normalize in case the slider delivers a float (e.g. 30.0); Spotify's
    # ``limit`` parameter and the length arithmetic downstream expect an int.
    playlist_length = int(playlist_length)

    print("Getting image inference from transformer")
    mood_dict = get_image_mood_dict_from_transformer(img)
    print("Getting Sraddha Found Boolean from model")
    sraddha_found = get_sraddha(img)
    print("Building playlist")
    playlist = get_playlist(mood_dict, img, playlist_length, privacy, gen_mode, genre_choice, request)
    if playlist is None:
        # get_playlist returns None when the session has no Spotify token.
        playlist = "Spotipy account token not set"

    if sraddha_found:
        return gr.update(value=playlist, visible=True), gr.update(value=_sraddha_message(), visible=True)
    return gr.update(value=playlist, visible=True), gr.update(visible=False)
def get_image_mood_dict_from_transformer(img):
    """Classify the emotion in an image with the ViT mood model.

    Args:
        img: Path to an image file readable by ``torchvision.io.read_image``.

    Returns:
        dict mapping every label in ``mood_model.config.id2label`` to its
        softmax probability, as a Python float.
    """
    from torchvision.io import ImageReadMode

    # Force 3 channels: grayscale (1) or RGBA (4) images would not match the
    # processor's per-channel RGB normalization.
    image = read_image(img, mode=ImageReadMode.RGB)
    encoding = mood_feature_extractor(images=image, return_tensors="pt")
    pixel_values = encoding['pixel_values'].to(device)
    print('Running mood prediction')
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = mood_model(pixel_values).logits
    # Move to CPU first; .numpy() raises on CUDA tensors.
    probabilities = F.softmax(logits, dim=-1)[0].cpu().numpy()
    # Look labels up by class index so each one lines up with its logit,
    # independent of the mapping's insertion order.
    id2label = mood_model.config.id2label
    return {id2label[i]: float(p) for i, p in enumerate(probabilities)}
def get_sraddha(img):
    """Return whether the Sraddha classifier scores the image at 0.5 or above.

    Args:
        img: Image filepath, preprocessed by ``shred_model.prepare_image``.

    Returns:
        bool: always an explicit True/False (never None).
    """
    fixed_img = shred_model.prepare_image(img)
    # presumably a one-image batch with a single sigmoid output -- TODO confirm
    prob = SHRED_MODEL.predict(fixed_img)[0]
    return bool(prob >= .5)
def compute_mood(mood_dict):
    """Collapse emotion probabilities into a single mood score.

    'happy' counts in full, 'angry' at half weight and 'sad' at one tenth;
    all other labels are ignored. The input is printed for debugging.

    Args:
        mood_dict: Mapping containing at least 'happy', 'angry' and 'sad'.

    Returns:
        happy + 0.5 * angry + 0.1 * sad
    """
    print(mood_dict)
    happy = mood_dict['happy']
    angry = mood_dict['angry']
    sad = mood_dict['sad']
    return happy + angry * .5 + sad * .1
# Spotify seed genres for each choice offered in the "Choose a Genre" dropdown.
_GENRE_SEEDS = {
    'Rock': ['alt-rock', 'alternative', 'indie', 'r-n-b', 'rock'],
    'Hip-hop': ['hip-hop'],
    'Party': ['house', 'pop', 'party'],
    'Mellow': ['blues', 'jazz', 'happy'],
    'Indian': ['idm', 'indian'],
    'Pop': ['pop', 'new-age'],
    'Study': ['study', 'classical', 'jazz', 'happy', 'chill'],
    'Romance': ['romance', 'happy', 'pop'],
}


def _clamp_unit(value):
    """Clamp ``value`` into [0, 1], the range of Spotify's tunable track attributes."""
    return min(max(value, 0.0), 1.0)


def _recently_played_candidates(sp):
    """Collect candidate track URIs for the "Recently Played" mode.

    Combines the 50 most recently played tracks, up to 200 saved tracks and
    recommendations seeded by the first ten saved tracks. Order is arbitrary.
    """
    candidates = {x['track']['uri'] for x in sp.current_user_recently_played(limit=50)['items']}
    saved = []
    for offset in range(0, 200, 50):
        page = [x['track']['uri'] for x in sp.current_user_saved_tracks(limit=50, offset=offset)['items']]
        if not page:
            break  # end of the library; later offsets would be empty too
        saved += page
    candidates.update(saved)
    # Spotify takes at most five seeds per request and needs at least one, so
    # skip seed groups that a small library cannot fill.
    for start in (0, 5):
        seeds = saved[start:start + 5]
        if seeds:
            candidates.update(x['uri'] for x in sp.recommendations(seed_tracks=seeds, limit=50)['tracks'])
    return list(candidates)


def _genre_recommendations(sp, genres, mood, limit):
    """Recommend up to ``limit`` tracks seeded by ``genres`` and tuned to ``mood``."""
    response = sp.recommendations(
        seed_genres=genres, limit=limit,
        max_valence=_clamp_unit(mood + .15), min_valence=_clamp_unit(mood - .15),
        min_danceability=_clamp_unit(mood / 1.75), max_danceability=_clamp_unit(mood * 8),
        min_energy=_clamp_unit(mood / 2))
    return [x['uri'] for x in response['tracks']]


def _fill_to_length(sp, tracks, mood, playlist_length):
    """Pad ``tracks`` up to ``playlist_length`` with recommendations seeded by recent plays."""
    missing = playlist_length - len(tracks)
    if missing <= 0:
        return tracks
    print(f'Filling playlist with {missing} more songs (filter too big)')
    seeds = [x['track']['uri'] for x in sp.current_user_recently_played(limit=5)['items']]
    if not seeds:
        return tracks  # no listening history to seed recommendations from
    response = sp.recommendations(
        seed_tracks=seeds, limit=missing,
        min_valence=_clamp_unit(mood - .3), min_energy=_clamp_unit(mood / 3))
    return tracks + [x['uri'] for x in response['tracks']]


def get_playlist(mood_dict, img, playlist_length, privacy, gen_mode, genre_choice, request: gr.Request):
    """Create a mood-matched Spotify playlist for the logged-in user.

    Args:
        mood_dict: Emotion probabilities, reduced to a score by ``compute_mood``.
        img: Source image path, used as the playlist cover.
        playlist_length: Target number of tracks.
        privacy: "Public" or "Private".
        gen_mode: "Favorites", "Recently Played" or "By a Chosen Genre".
        genre_choice: Key of ``_GENRE_SEEDS`` used in "By a Chosen Genre" mode.
        request: Gradio request whose session may hold a Spotify token.

    Returns:
        The HTML returned by ``create_playlist``, or None when the session has
        no Spotify token.
    """
    token = request.request.session.get('token')
    if not token:
        return None

    mood = compute_mood(mood_dict)
    label = genre_choice if gen_mode == "By a Chosen Genre" else gen_mode
    playlist_name = "Mood " + str(round(mood * 100, 1)) + f": {label}"
    sp = spotipy.Spotify(token)

    if gen_mode == 'By a Chosen Genre':
        final_track_list = _genre_recommendations(sp, _GENRE_SEEDS[genre_choice], mood, playlist_length)
    else:
        if gen_mode == 'Recently Played':
            candidates = _recently_played_candidates(sp)
        else:
            candidates = aggregate_top_tracks(sp, aggregate_favorite_artists(sp))
        final_track_list = filter_tracks(sp, candidates, mood, playlist_length)
        # If too few tracks fit the mood filter, generate some results anyway.
        final_track_list = _fill_to_length(sp, final_track_list, mood, playlist_length)

    return create_playlist(sp, img, final_track_list, playlist_name, privacy)
def _encode_cover_image(img):
    """Return ``img`` as a base64-encoded JPEG thumbnail of at most 300x300 pixels."""
    with Image.open(img) as im_file:
        # JPEG supports neither alpha nor palette modes; converting lets
        # RGBA/P images (e.g. PNGs) be saved instead of raising OSError.
        thumb = im_file.convert("RGB")
    thumb.thumbnail((300, 300))
    buffered = BytesIO()
    thumb.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue())


def _upload_cover_image(sp, playlist_id, img, attempts=2, retry_delay=5):
    """Best-effort playlist cover upload; failures are logged, never raised.

    Returns:
        True if an upload succeeded, otherwise False.
    """
    try:
        payload = _encode_cover_image(img)
    except Exception as e:
        print(f"Could not prepare cover image: {e}")
        return False
    for attempt in range(1, attempts + 1):
        try:
            sp.playlist_upload_cover_image(playlist_id, payload)
            return True
        except Exception as e:  # cover art is cosmetic; never fail the playlist over it
            print(f"Cover image upload failed (attempt {attempt}/{attempts}): {e}")
            if attempt < attempts:
                print("Retrying...")
                time.sleep(retry_delay)
    return False


def create_playlist(sp, img, tracks, playlist_name, privacy):
    """Create a Spotify playlist holding ``tracks`` and return an HTML embed of it.

    Args:
        sp: Authenticated ``spotipy.Spotify`` client.
        img: Path of the source image, uploaded (best effort) as the cover.
        tracks: Track URIs to add.
        playlist_name: Title of the new playlist.
        privacy: "Public" makes the playlist public; anything else is private.

    Returns:
        An ``<iframe>`` embedding the playlist, or a plain message when
        ``tracks`` is empty. In that case no playlist is created.
    """
    # Bail out before touching the account so no empty playlist is left behind.
    if not tracks:
        return """No tracks could be generated from this image"""
    public = privacy == "Public"
    user_id = sp.current_user()['id']
    playlist_description = "This playlist was created using the img-to-music application built by the best boyfriend there ever was and ever will be"
    playlist_data = sp.user_playlist_create(user_id, playlist_name, public=public,
                                            description=playlist_description)
    playlist_id = playlist_data['id']
    sp.user_playlist_add_tracks(user_id, playlist_id, tracks)
    _upload_cover_image(sp, playlist_id, img)
    # presumably gives Spotify time to make the new playlist/cover available
    # to the embed -- TODO confirm whether this delay is still needed
    time.sleep(3)
    iframe_embedding = f"""<iframe style="border-radius:12px" src="https://open.spotify.com/embed/playlist/{playlist_id}" width="100%" height="352" frameBorder="0" allowfullscreen="" allow="autoplay; clipboard-write; encrypted-media; fullscreen; picture-in-picture" loading="lazy"></iframe>"""
    return iframe_embedding
def aggregate_favorite_artists(sp, min_artists=200):
    """Collect URIs of the user's favourite artists, padded with related artists.

    Starts from the user's top artists for all three time ranges plus the first
    page (50) of followed artists, de-duplicated by artist name. While fewer than
    ``min_artists`` are known, the related artists of each collected artist are
    appended in order, until the target is met or every collected artist has
    been expanded.

    Args:
        sp: Authenticated Spotify client.
        min_artists: Number of artists to try to reach (default 200).

    Returns:
        list of artist URIs in discovery order (may be empty).
    """
    seen_names = set()
    artist_uris = []

    def add(artists):
        # De-duplicate by display name, keeping the first URI seen for a name.
        for artist in artists:
            if artist['name'] not in seen_names:
                seen_names.add(artist['name'])
                artist_uris.append(artist['uri'])

    for time_range in ('short_term', 'medium_term', 'long_term'):
        add(sp.current_user_top_artists(limit=50, time_range=time_range)['items'])
    add(sp.current_user_followed_artists(limit=50)['artists']['items'])

    # The `i < len(...)` bound stops once every artist has been expanded, and
    # also covers users with no top/followed artists (indexing [0] would fail).
    i = 0
    while len(artist_uris) < min_artists and i < len(artist_uris):
        add(sp.artist_related_artists(artist_uris[i])['artists'])
        i += 1
    return artist_uris
def aggregate_top_tracks(sp, top_artists_uri):
    """Return the top-track URIs of every artist in ``top_artists_uri``.

    Artists are queried in the given order, and each artist's tracks keep the
    order Spotify returns them in. A track credited to several listed artists
    (e.g. a feature) appears only once, at its first position, so it cannot be
    picked twice for the same playlist.

    Args:
        sp: Authenticated Spotify client.
        top_artists_uri: Artist URIs/ids to query.

    Returns:
        list of unique track URIs.
    """
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(
        track['uri']
        for artist in top_artists_uri
        for track in sp.artist_top_tracks(artist)['tracks']
    ))
def _track_fits_mood(mood, valence, danceability, energy):
    """Return True if a track's audio features fit the band that ``mood`` falls in."""
    if mood < .1:
        return valence <= mood + .15 and danceability <= mood * 8 and energy <= mood * 10
    if mood < .25:
        return ((mood - .1) <= valence <= (mood + .1)
                and danceability <= mood * 4 and energy <= mood * 5)
    if mood < .5:
        return (mood - .05 <= valence <= mood + .05
                and danceability <= mood * 1.75 and energy <= mood * 1.75)
    if mood < .75:
        return (mood - .1 <= valence <= mood + .1
                and danceability >= mood / 2.5 and energy >= mood / 2)
    if mood < .9:
        return (mood - .1 <= valence <= mood + .1
                and danceability >= mood / 2 and energy >= mood / 1.75)
    return (mood - .15 <= valence <= 1
            and danceability >= mood / 1.75 and energy >= mood / 1.5)


def filter_tracks(sp, top_tracks_uri, mood, playlist_length):
    """Pick up to ``playlist_length`` tracks whose audio features match ``mood``.

    Candidates are visited in random order. Their audio features are fetched in
    batches, and a track is kept when it fits the mood band (see
    ``_track_fits_mood``). Tracks without audio features are skipped.

    Args:
        sp: Spotify client exposing ``audio_features``.
        top_tracks_uri: Candidate track URIs; the caller's list is not modified.
        mood: Mood score from ``compute_mood``.
        playlist_length: Maximum number of tracks to return.

    Returns:
        list of selected track URIs (possibly fewer than requested).
    """
    if not top_tracks_uri:
        return []  # avoid an audio-features request with no ids
    # Shuffle a copy so the caller's list is left untouched.
    candidates = list(top_tracks_uri)
    np.random.shuffle(candidates)

    # Batch network requests: the audio-features endpoint takes at most 100 ids.
    BATCH_SIZE = 100
    all_track_data = []
    for start in range(0, len(candidates), BATCH_SIZE):
        all_track_data += sp.audio_features(candidates[start:start + BATCH_SIZE])

    selected_tracks_uri = []
    for track, track_data in zip(candidates, all_track_data):
        if track_data is None:
            continue
        if _track_fits_mood(mood, track_data['valence'], track_data['danceability'], track_data['energy']):
            selected_tracks_uri.append(track)
            if len(selected_tracks_uri) >= playlist_length:
                break
    return selected_tracks_uri
# Define login and frontend
# Port the server listens on (the run call binds 7860, the port Hugging Face
# Spaces serve).
PORT_NUMBER = 7860
# OAuth app credentials. Prefer environment variables (e.g. Space secrets).
# The literal fallbacks keep the existing deployment working, but this client
# secret has been published in source and should be rotated.
SPOTIPY_CLIENT_ID = os.environ.get('SPOTIPY_CLIENT_ID', '2320153024d042c8ba138a108066246c')
SPOTIPY_CLIENT_SECRET = os.environ.get('SPOTIPY_CLIENT_SECRET', 'da2746490f6542a3b0cfcff50893e8e8')
# For local development set SPOTIPY_REDIRECT_URI=http://localhost:7860
SPOTIPY_REDIRECT_URI = os.environ.get('SPOTIPY_REDIRECT_URI', "https://Bokanovskii-Image-to-music.hf.space")
SCOPE = 'ugc-image-upload playlist-read-private playlist-read-collaborative playlist-modify-private playlist-modify-public user-top-read user-read-playback-position user-read-recently-played user-read-email user-follow-read user-library-modify user-library-read user-read-email user-read-private user-read-playback-state user-modify-playback-state user-read-currently-playing app-remote-control streaming'
sp_oauth = oauth2.SpotifyOAuth(SPOTIPY_CLIENT_ID, SPOTIPY_CLIENT_SECRET, SPOTIPY_REDIRECT_URI, scope=SCOPE)
app = FastAPI()
# The session cookie carries each user's Spotify access token, so it must be
# signed with a strong key. Use SESSION_SECRET_KEY when provided; otherwise
# generate a random per-process key (sessions then reset on server restart).
app.add_middleware(SessionMiddleware,
                   secret_key=os.environ.get('SESSION_SECRET_KEY') or secrets.token_hex(32))
def _login_page(auth_url, error=False):
    """Render the Spotify login landing page as HTML.

    Args:
        auth_url: Spotify authorization URL for the login link.
        error: When True, show a "couldn't connect" notice (smaller heading font).
    """
    import html

    font_size = "1.25rem" if error else "1.75rem"
    notice = ("<p> The server couldn't make a connection with Spotify: please try again </p>\n"
              if error else "")
    # Escape so the URL's '&' separators (and any quote) are valid inside the attribute.
    href = html.escape(auth_url, quote=True)
    return f"""<div style="text-align: center; max-width: 1000px; margin: 0 auto;">
<div
style="
align-items: center;
gap: 0.8rem;
font-size: {font_size};
"
>
<h3 style="font-weight: 900; margin-bottom: 30px; margin-top: 20px;">
Image to Music Generator
</h3>
{notice}<a href='{href}'>Login to Spotify</a>
<p>
</p>
<p>
</p>
<small>
Click 'Open in a new window/tab'
</small>
<div
style="
align-items: center;
gap: 0.8rem;
font-size: 1rem;
"
>
<small>
This applet requires a whitelisted Spotify account (contact Charlie Ward)
</small>
</div>
</div>
</div>"""


# Registered at "/" because SPOTIPY_REDIRECT_URI points at the site root: Spotify
# sends users back here with ?code=... after they authorize.
@app.get("/", response_class=HTMLResponse)
async def homepage(request: Request):
    """Serve the login page and complete the Spotify OAuth callback.

    When the request URL carries an authorization code, exchange it for an
    access token, store the token in the session and redirect to the Gradio
    UI at /gradio. Otherwise show the login page; if the exchange fails, show
    it with an error notice.
    """
    url = str(request.url)
    auth_url = sp_oauth.get_authorize_url()
    try:
        code = sp_oauth.parse_response_code(url)
        # parse_response_code returns the URL unchanged when it has no code.
        if code != url:
            request.session['token'] = sp_oauth.get_access_token(code, as_dict=False, check_cache=False)
            return RedirectResponse("/gradio")
    except Exception as e:
        print(f"Spotify login failed: {e}")
        return _login_page(auth_url, error=True)
    return _login_page(auth_url)
# Gradio UI. The components created here are wired to `main` by the click
# listener at the end of the block.
# NOTE(review): the original indentation of this block was lost; the nesting of
# the layout containers below is a best-effort reconstruction -- confirm layout.
with gr.Blocks(css="style.css") as demo:
    with gr.Column(elem_id="col-container"):
        # Page header (raw HTML).
        gr.HTML("""<div style="text-align: center; max-width: 700px; margin: 0 auto;">
<div
style="
display: inline-flex;
align-items: center;
gap: 0.8rem;
font-size: 1.75rem;
"
>
<h1 style="font-weight: 900; margin-bottom: 7px; margin-top: 5px;">
Image to Music Generator
</h1>""")
        # type="filepath": handlers receive the uploaded image as a file path.
        input_img = gr.Image(type="filepath", elem_id="input-img")
        # Starts hidden; `main` makes it visible when Sraddha is detected.
        sraddhas_box = gr.HTML(label="Sraddha's Box", elem_id="sraddhas-box", visible=False)
        playlist_output = gr.HTML(label="Generated Playlist", elem_id="app-output", visible=True)
        with gr.Accordion(label="Playlist Generation Options", open=False):
            playlist_length = gr.Slider(minimum=5, maximum=100, value=30, step=5,
                                        label="Playlist Length", elem_id="playlist-length")
            with gr.Row():
                privacy = gr.Radio(label="Playlist Privacy Level", choices=["Public", "Private"],
                                   value="Private")
                gen_mode = gr.Radio(label="Recommendation Base", choices=["Favorites", "Recently Played", "By a Chosen Genre"], value="Favorites")
            # Starts hidden; shown only while gen_mode is "By a Chosen Genre".
            with gr.Row(visible=False) as genre_choice_row:
                genre_choice = gr.Dropdown(label='Choose a Genre', choices=['Rock', 'Pop', 'Hip-hop', 'Party', 'Mellow', 'Indian', 'Study', 'Romance'], value='Pop')

        def sraddha_box_hide():
            """Hide Sraddha's box; bound to the Generate button's click."""
            return {sraddhas_box: gr.update(visible=False)}

        def genre_dropdown_toggle(gen_mode):
            """Show the genre dropdown row only when gen_mode is "By a Chosen Genre"."""
            if gen_mode == 'By a Chosen Genre':
                return {genre_choice_row: gr.update(visible=True)}
            else:
                return {genre_choice_row: gr.update(visible=False)}
        generate = gr.Button("Generate Playlist from Image")
        # Footer HTML.
        article = """
<div class="footer">
<p>
Built for Sraddha: playlist generation from image inference
</p>
<p>
Sending Love 🤗
</p>
</div>
"""
        gr.HTML(article)
    gen_mode.change(genre_dropdown_toggle, inputs=[gen_mode], outputs=[genre_choice_row])
    # Two listeners on the same button: one hides Sraddha's box, the other runs
    # the full image-to-playlist pipeline (registered under api_name
    # "img-to-music").
    generate.click(sraddha_box_hide, outputs=[sraddhas_box])
    generate.click(main, inputs=[input_img, playlist_length, privacy, gen_mode, genre_choice],
                   outputs=[playlist_output, sraddhas_box], api_name="img-to-music")
# Serve the Gradio UI under /gradio on the same FastAPI app that handles login.
gradio_app = gr.mount_gradio_app(app, demo, "/gradio")

# Start the server only when run as a script (as `python app.py` does), so
# importing this module -- e.g. to serve `app` with an external ASGI server --
# does not block on a second server.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)