"""Image-to-music app: turn an uploaded image into a Spotify playlist.

The image is run through a ViT emotion classifier to get a scalar "mood",
which is then used to pick or filter Spotify tracks.  A FastAPI app handles
the Spotify OAuth redirect at ``/`` and mounts the Gradio UI at ``/gradio``.

NOTE(review): several HTML string literals in this file look like their
markup was stripped at some point (plain text separated by blank lines).
Their visible text is preserved as-is; restore the real markup from version
control if it is available.
"""
import base64
import html
import os
import time
from io import BytesIO

import gradio as gr
import numpy as np
import spotipy
import tensorflow as tf
import torch
import uvicorn
from fastapi import FastAPI
from PIL import Image
from spotipy import oauth2
from starlette.middleware.sessions import SessionMiddleware
from starlette.requests import Request
from starlette.responses import HTMLResponse, RedirectResponse
from fastapi.responses import HTMLResponse
from fastapi.responses import RedirectResponse
from torch.nn import functional as F
from torchvision.io import ImageReadMode, read_image
from transformers import ViTForImageClassification, ViTImageProcessor

import shred_model

# Xception fine tuned from pretrained imagenet weights for identifying Sraddha
SRADDHA_MODEL_PATH = "shred_model"
SHRED_MODEL = tf.keras.models.load_model(SRADDHA_MODEL_PATH)
# Never assigned in this file: the access token is stored per user in the
# session by homepage().  Kept for backward compatibility.
SPOTIPY_TOKEN = None

MOOD_MODEL_NAME = "jayanta/google-vit-base-patch16-224-cartoon-emotion-detection"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Grabbing model")
mood_model = ViTForImageClassification.from_pretrained(MOOD_MODEL_NAME)
mood_model.eval()
mood_model.to(device)
print("Grabbing feature extractor")
mood_feature_extractor = ViTImageProcessor.from_pretrained(MOOD_MODEL_NAME)

VALENTINES_JOKES = [
    "Why shouldn't you trust a pastry chef on Valentine's Day? Because he will dessert you.",
    "What do you give your Valentine in France? A big quiche.",
    "What did the tortoise say on Valentine's Day? I turt-ally love you.",
    "How did the squirrel get his Valentine's attention? He acted like a nut.",
    "What do you call sweets that can keep a beat? Candy rappers.",
    "What did the paper clip say to the magnet? I find you very attractive.",
    "What did the caclulator say to the pencil? You can count on me.",
]

# UI genre label -> Spotify seed genres used for recommendations.
GENRE_MAP = {
    'Rock': ['alt-rock', 'alternative', 'indie', 'r-n-b', 'rock'],
    'Hip-hop': ['hip-hop'],
    'Party': ['house', 'pop', 'party'],
    'Mellow': ['blues', 'jazz', 'happy'],
    'Indian': ['idm', 'indian'],
    'Pop': ['pop', 'new-age'],
    'Study': ['study', 'classical', 'jazz', 'happy', 'chill'],
    'Romance': ['romance', 'happy', 'pop'],
}


def main(img, playlist_length, privacy, gen_mode, genre_choice, request: gr.Request):
    """Gradio click handler: build a playlist for ``img``.

    Args:
        img: File path of the uploaded image (``gr.Image(type="filepath")``),
            or ``None`` when nothing was uploaded.
        playlist_length: Requested number of tracks.
        privacy: ``"Public"`` or ``"Private"``.
        gen_mode: ``"Favorites"``, ``"Recently Played"`` or
            ``"By a Chosen Genre"``.
        genre_choice: Key of ``GENRE_MAP``; only used in genre mode.
        request: Gradio request, used to read the session token.

    Returns:
        ``None`` when no image was given, otherwise a pair of ``gr.update``
        objects for the playlist output and the "Sraddha's box" output.
    """
    if img is None:
        return None

    print("Getting image inference from tansformer")
    mood_dict = get_image_mood_dict_from_transformer(img)
    print("Getting Sraddha Found Boolean from model")
    sraddha_found = get_sraddha(img)
    print("Building playlist")
    playlist = get_playlist(mood_dict, img, playlist_length, privacy,
                            gen_mode, genre_choice, request)
    if playlist is None:
        playlist = "Spotipy account token not set"
    playlist_update = gr.update(value=playlist, visible=True)

    if not sraddha_found:
        return playlist_update, gr.update(visible=False)

    # randint's upper bound is exclusive, so pass len() (not len() - 1) to
    # make every joke, including the last one, selectable.
    joke = VALENTINES_JOKES[np.random.randint(len(VALENTINES_JOKES))]
    sraddha_msg = (
        "Sraddha, you are the love of my life and seeing you always lifts my spirits."
        " Hopefully these tunes and a joke can do the same for you.\n\n"
        f"\n\n{joke}\n\n"
        "- With Love, Scoob"
    )
    return playlist_update, gr.update(value=sraddha_msg, visible=True)


def get_image_mood_dict_from_transformer(img):
    """Classify the image at path ``img`` and return ``{label: probability}``.

    Labels come from ``mood_model.config.id2label``; probabilities are the
    softmax of the model logits for the single input image.
    """
    # Force 3 channels so RGBA / grayscale files do not reach the processor
    # with an unexpected channel count.
    pixels = read_image(img, mode=ImageReadMode.RGB)
    encoding = mood_feature_extractor(images=pixels, return_tensors="pt")
    pixel_values = encoding['pixel_values'].to(device)
    print('Running mood prediction')
    with torch.no_grad():  # inference only
        logits = mood_model(pixel_values).logits
    # .cpu() is required before .numpy() when the model runs on CUDA.
    probabilities = F.softmax(logits, dim=-1).detach().cpu().numpy()[0]
    return dict(zip(mood_model.config.id2label.values(), probabilities))


def get_sraddha(img):
    """Return ``True`` when SHRED_MODEL scores the image at >= 0.5, else ``False``."""
    fixed_img = shred_model.prepare_image(img)
    prob = SHRED_MODEL.predict(fixed_img)[0]
    return bool(prob >= .5)


def compute_mood(mood_dict):
    """Collapse the emotion probabilities into one mood score.

    Weights: ``happy`` x1, ``angry`` x0.5, ``sad`` x0.1.
    """
    print(mood_dict)
    return mood_dict['happy'] + mood_dict['angry'] * .5 + mood_dict['sad'] * .1


def _clamp_unit(value):
    """Clamp ``value`` to [0, 1], the range of Spotify's audio attributes."""
    return min(1.0, max(0.0, value))


def _collect_recent_and_saved_tracks(sp):
    """Return unique track URIs from recent plays, saved tracks and recommendations."""
    uris = {x['track']['uri'] for x in sp.current_user_recently_played(limit=50)['items']}
    first_few = [x['track']['uri'] for x in sp.current_user_saved_tracks(limit=50)['items']]
    uris.update(first_few)
    for offset in (50, 100, 150):
        saved = sp.current_user_saved_tracks(limit=50, offset=offset)['items']
        uris.update(x['track']['uri'] for x in saved)
    # Users with few saved tracks yield empty seed slices; skip those rather
    # than issuing a recommendation request with no seeds.
    for seed in (first_few[:5], first_few[5:10]):
        if seed:
            recs = sp.recommendations(seed_tracks=seed, limit=50)['tracks']
            uris.update(x['uri'] for x in recs)
    return list(uris)


def get_playlist(mood_dict, img, playlist_length, privacy, gen_mode, genre_choice, request: gr.Request):
    """Create a Spotify playlist matching the image mood.

    Returns:
        The HTML returned by ``create_playlist``, or ``None`` when the session
        holds no Spotify token.
    """
    token = request.request.session.get('token')
    if not token:
        return None

    playlist_length = int(playlist_length)
    # Plain float keeps numpy scalars out of names and request parameters.
    mood = float(compute_mood(mood_dict))
    by_genre = gen_mode == 'By a Chosen Genre'
    label = genre_choice if by_genre else gen_mode
    playlist_name = f"Mood {round(mood * 100, 1)}: {label}"
    sp = spotipy.Spotify(token)

    if by_genre:
        # Bounds are clamped to [0, 1]; since track attributes already lie in
        # that range this selects the same tracks while keeping the request
        # parameters valid.
        recs = sp.recommendations(
            seed_genres=GENRE_MAP[genre_choice],
            limit=playlist_length,
            max_valence=_clamp_unit(mood + .15),
            min_valence=_clamp_unit(mood - .15),
            min_danceability=mood / 1.75,
            max_danceability=_clamp_unit(mood * 8),
            min_energy=mood / 2)['tracks']
        final_track_list = [x['uri'] for x in recs]
    else:
        if gen_mode == 'Recently Played':
            top_tracks_uri = _collect_recent_and_saved_tracks(sp)
        else:
            top_artists_uri = aggregate_favorite_artists(sp)
            top_tracks_uri = aggregate_top_tracks(sp, top_artists_uri)
        final_track_list = filter_tracks(sp, top_tracks_uri, mood, playlist_length)
        # If too few tracks fit the filter: generate some results anyways
        if len(final_track_list) != playlist_length:
            diff = playlist_length - len(final_track_list)
            print(f'Filling playlist with {diff} more songs (filter too big)')
            seed = [x['track']['uri']
                    for x in sp.current_user_recently_played(limit=5)['items']]
            if seed:  # recommendations need at least one seed
                recs = sp.recommendations(
                    seed_tracks=seed,
                    limit=diff,
                    min_valence=_clamp_unit(mood - .3),
                    min_energy=mood / 3)['tracks']
                final_track_list += [x['uri'] for x in recs]

    return create_playlist(sp, img, final_track_list, playlist_name, privacy)


def create_playlist(sp, img, tracks, playlist_name, privacy):
    """Create the playlist, add ``tracks``, upload ``img`` as cover, return embed HTML.

    Args:
        sp: Authenticated ``spotipy.Spotify`` client.
        img: File path of the image to use as the cover.
        tracks: Track URIs to add.
        playlist_name: Name of the new playlist.
        privacy: ``"Public"`` makes the playlist public; anything else private.

    Returns:
        An HTML string: either an error message when ``tracks`` is empty or
        an iframe embedding the new playlist.
    """
    # Check first so an empty playlist is never created on the account.
    if len(tracks) == 0:
        return """No tracks could be generated from this image"""

    is_public = privacy == "Public"
    user_id = sp.current_user()['id']
    playlist_description = "This playlist was created using the img-to-music application built by the best boyfriend there ever was and ever will be"
    playlist_data = sp.user_playlist_create(user_id, playlist_name, public=is_public,
                                            description=playlist_description)
    playlist_id = playlist_data['id']
    sp.user_playlist_add_tracks(user_id, playlist_id, tracks)

    def upload_img():
        with Image.open(img) as im_file:
            im_file.thumbnail((300, 300))
            buffered = BytesIO()
            # JPEG cannot store alpha/palette modes; convert before saving.
            im_file.convert("RGB").save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue())
        sp.playlist_upload_cover_image(playlist_id, img_str)

    def retry_upload():
        # The cover is best-effort: a second failure is logged, not raised.
        time.sleep(5)
        try:
            upload_img()
        except Exception as e:
            print(e)

    try:
        upload_img()
    except spotipy.exceptions.SpotifyException as e:
        print(f"SpotiftException on image upload: {e}")
        print("Retrying")
        retry_upload()
    except OSError as e:
        # requests' ReadTimeout derives from IOError/OSError, so this catches
        # timeouts without referencing the (never imported) requests module.
        print(f"Image upload request timeout: {e}")
        print("Retrying...")
        retry_upload()

    time.sleep(3)
    # NOTE(review): the original iframe markup was not recoverable (the
    # literal was empty); this is a standard Spotify playlist embed -- confirm
    # attributes against the previous version.
    iframe_embedding = (
        f'<iframe src="https://open.spotify.com/embed/playlist/{playlist_id}" '
        'width="100%" height="380" frameborder="0" allow="encrypted-media"></iframe>'
    )
    return iframe_embedding


def _add_new_artists(artists, seen_names, uris):
    """Append to ``uris`` the URI of each artist whose name is not in ``seen_names``."""
    for artist_data in artists:
        if artist_data['name'] not in seen_names:
            seen_names.add(artist_data['name'])
            uris.append(artist_data['uri'])


def aggregate_favorite_artists(sp):
    """Return URIs of the user's top and followed artists, padded with related artists.

    Tries to reach 200 unique artists (deduplicated by name) by walking the
    related artists of those already collected; may return fewer.
    """
    top_artists_name = set()
    top_artists_uri = []

    for time_range in ('short_term', 'medium_term', 'long_term'):
        top = sp.current_user_top_artists(limit=50, time_range=time_range)
        _add_new_artists(top['items'], top_artists_name, top_artists_uri)

    followed = sp.current_user_followed_artists(limit=50)
    _add_new_artists(followed['artists']['items'], top_artists_name, top_artists_uri)

    # attempt to garauntee 200 artists
    # The index bound also covers a user with no artists at all, which
    # previously raised IndexError on the first lookup.
    i = 0
    while len(top_artists_uri) < 200 and i < len(top_artists_uri):
        related = sp.artist_related_artists(top_artists_uri[i])
        i += 1
        _add_new_artists(related['artists'], top_artists_name, top_artists_uri)
    # could build in a deeper artist recommendation finder here
    # would do this if it was going to production but Sraddha follows lots of artists
    return top_artists_uri


def aggregate_top_tracks(sp, top_artists_uri):
    """Return the top-track URIs of every artist in ``top_artists_uri``, in order."""
    return [
        track_data['uri']
        for artist in top_artists_uri
        for track_data in sp.artist_top_tracks(artist)['tracks']
    ]


def _track_fits_mood(mood, valence, danceability, energy):
    """Return whether a track's audio attributes fall in the band chosen for ``mood``."""
    if mood < .1:
        return (valence <= mood + .15
                and danceability <= mood * 8
                and energy <= mood * 10)
    if mood < .25:
        return ((mood - .1) <= valence <= (mood + .1)
                and danceability <= mood * 4
                and energy <= mood * 5)
    if mood < .5:
        return (mood - .05 <= valence <= mood + .05
                and danceability <= mood * 1.75
                and energy <= mood * 1.75)
    if mood < .75:
        return (mood - .1 <= valence <= mood + .1
                and danceability >= mood / 2.5
                and energy >= mood / 2)
    if mood < .9:
        return (mood - .1 <= valence <= mood + .1
                and danceability >= mood / 2
                and energy >= mood / 1.75)
    return (mood - .15 <= valence <= 1
            and danceability >= mood / 1.75
            and energy >= mood / 1.5)


def filter_tracks(sp, top_tracks_uri, mood, playlist_length):
    """Pick up to ``playlist_length`` tracks whose audio features match ``mood``.

    ``top_tracks_uri`` is shuffled in place, so repeated calls give different
    selections.  Tracks without audio features are skipped.
    """
    np.random.shuffle(top_tracks_uri)

    # Batch network requests; an empty input makes no request at all.
    BATCH_SIZE = 100
    all_track_data = []
    for start in range(0, len(top_tracks_uri), BATCH_SIZE):
        all_track_data += sp.audio_features(top_tracks_uri[start:start + BATCH_SIZE])

    selected_tracks_uri = []
    for track, track_data in zip(top_tracks_uri, all_track_data):
        if track_data is None:
            continue
        if _track_fits_mood(mood, track_data['valence'],
                            track_data['danceability'], track_data['energy']):
            selected_tracks_uri.append(track)
        if len(selected_tracks_uri) >= playlist_length:
            break
    return selected_tracks_uri


# Define login and frontend
PORT_NUMBER = 8080
# SECURITY: these credentials are committed to source.  They can now be
# overridden from the environment; the in-file defaults are kept only for
# backward compatibility and should be rotated.
SPOTIPY_CLIENT_ID = os.environ.get('SPOTIPY_CLIENT_ID', '2320153024d042c8ba138a108066246c')
SPOTIPY_CLIENT_SECRET = os.environ.get('SPOTIPY_CLIENT_SECRET', 'da2746490f6542a3b0cfcff50893e8e8')
#SPOTIPY_REDIRECT_URI = 'http://localhost:7860'
SPOTIPY_REDIRECT_URI = os.environ.get('SPOTIPY_REDIRECT_URI',
                                      "https://Bokanovskii-Image-to-music.hf.space")
SCOPE = ' '.join([
    'ugc-image-upload',
    'playlist-read-private',
    'playlist-read-collaborative',
    'playlist-modify-private',
    'playlist-modify-public',
    'user-top-read',
    'user-read-playback-position',
    'user-read-recently-played',
    'user-read-email',
    'user-follow-read',
    'user-library-modify',
    'user-library-read',
    'user-read-email',
    'user-read-private',
    'user-read-playback-state',
    'user-modify-playback-state',
    'user-read-currently-playing',
    'app-remote-control',
    'streaming',
])
sp_oauth = oauth2.SpotifyOAuth(SPOTIPY_CLIENT_ID, SPOTIPY_CLIENT_SECRET,
                               SPOTIPY_REDIRECT_URI, scope=SCOPE)

app = FastAPI()
app.add_middleware(SessionMiddleware,
                   secret_key=os.environ.get('SESSION_SECRET_KEY', "w.o.w"))


def _login_page(auth_url, connection_failed=False):
    """Build the login page HTML, optionally with the connection-error notice."""
    notice = ""
    if connection_failed:
        notice = "\n\nThe server couldn't make a connection with Spotify: please try again\n\n\n"
    # NOTE(review): auth_url was computed but never emitted in the page, so
    # the login text is wrapped in a link to it here -- confirm the markup.
    login_link = f'<a href="{html.escape(auth_url, quote=True)}">Login to Spotify</a>\n'
    return (
        "\n\nImage to Music Generator\n\n\n"
        + notice
        + login_link
        + "\n\nClick 'Open in a new window/tab'\n"
          "This applet requires a whitelisted Spotify account (contact Charlie Ward) "
    )


@app.get('/', response_class=HTMLResponse)
async def homepage(request: Request):
    """Serve the login page, or finish the OAuth flow and redirect to ``/gradio``.

    When the request URL carries an authorization code, it is exchanged for an
    access token that is stored in the session under ``'token'``.
    """
    url = str(request.url)
    auth_url = sp_oauth.get_authorize_url()
    try:
        code = sp_oauth.parse_response_code(url)
        # code == url is how this script detects that no code was present.
        if code != url:
            request.session['token'] = sp_oauth.get_access_token(
                code, as_dict=False, check_cache=False)
            return RedirectResponse("/gradio")
    except Exception:
        # Any failure while exchanging the code: show the login page again.
        return _login_page(auth_url, connection_failed=True)
    return _login_page(auth_url)


with gr.Blocks(css="style.css") as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML("\n\nImage to Music Generator\n\n")

        input_img = gr.Image(type="filepath", elem_id="input-img")
        sraddhas_box = gr.HTML(label="Sraddha's Box", elem_id="sraddhas-box", visible=False)
        playlist_output = gr.HTML(label="Generated Playlist", elem_id="app-output", visible=True)

        with gr.Accordion(label="Playlist Generation Options", open=False):
            playlist_length = gr.Slider(minimum=5, maximum=100, value=30, step=5,
                                        label="Playlist Length", elem_id="playlist-length")
            with gr.Row():
                privacy = gr.Radio(label="Playlist Privacy Level",
                                   choices=["Public", "Private"], value="Private")
                gen_mode = gr.Radio(label="Recommendation Base",
                                    choices=["Favorites", "Recently Played", "By a Chosen Genre"],
                                    value="Favorites")
            with gr.Row(visible=False) as genre_choice_row:
                genre_choice = gr.Dropdown(
                    label='Choose a Genre',
                    choices=['Rock', 'Pop', 'Hip-hop', 'Party', 'Mellow', 'Indian', 'Study', 'Romance'],
                    value='Pop')

        def sraddha_box_hide():
            """Hide the Sraddha box while a new playlist is generated."""
            return {sraddhas_box: gr.update(visible=False)}

        def genre_dropdown_toggle(gen_mode):
            """Show the genre dropdown only in 'By a Chosen Genre' mode."""
            return {genre_choice_row: gr.update(visible=gen_mode == 'By a Chosen Genre')}

        generate = gr.Button("Generate Playlist from Image")

        article = """ """
        gr.HTML(article)

    gen_mode.change(genre_dropdown_toggle, inputs=[gen_mode], outputs=[genre_choice_row])
    generate.click(sraddha_box_hide, outputs=[sraddhas_box])
    generate.click(main,
                   inputs=[input_img, playlist_length, privacy, gen_mode, genre_choice],
                   outputs=[playlist_output, sraddhas_box],
                   api_name="img-to-music")

gradio_app = gr.mount_gradio_app(app, demo, "/gradio")
uvicorn.run(app, host="0.0.0.0", port=7860)