File size: 6,381 Bytes
ac6138f 0ea3806 ac6138f 0ea3806 ac6138f 8e68708 ac6138f afff708 8e68708 afff708 ac6138f 8e68708 ac6138f afff708 ac6138f 500010a ac6138f 500010a afff708 ac6138f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import os
import json
import gradio as gr
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
# Shared sentence encoder, loaded once at import time and reused by every
# query encoded in get_n_weighted_scores().
model = SentenceTransformer(model_name_or_path='sentence-transformers/all-MiniLM-L6-v2')
def get_n_weighted_scores(embeddings, query, n, objective_weight, subjective_weight):
    """Return the ``n`` best-matching anime for a free-text query.

    For every anime the encoded query is compared (cosine similarity) with
    its objective embedding and with each of its review embeddings. Every
    review gets the weighted score
    ``objective_score * objective_weight + review_score * subjective_weight``,
    and the anime is represented by its best-scoring review.

    Args:
        embeddings: mapping of anime key -> record holding
            ``'objective_embedding'`` and ``'subjective_embeddings'``.
            NOTE(review): the objective comparison must yield a single value
            (``.item()``), so that embedding is presumably stored as a one-row
            2-D list -- confirm against the data file.
        query: text to encode with the module-level ``model``.
        n: maximum number of results; cast to ``int`` so float slider values
            such as ``3.0`` are accepted as a slice bound.
        objective_weight: weight of the objective (synopsis/metadata) similarity.
        subjective_weight: weight of the review similarity.

    Returns:
        Up to ``n`` ``(key, weighted_score, review_index)`` tuples sorted by
        ``weighted_score`` in descending order, where ``review_index`` is the
        position of the best review in ``'subjective_embeddings'``.
    """
    # Encode the query once; cosine_similarity expects rows of vectors (2-D).
    query_embedding = [model.encode(query)]
    weighted_scores = []
    for key, value in embeddings.items():
        objective_score = cosine_similarity(query_embedding, value['objective_embedding']).item()
        review_scores = cosine_similarity(query_embedding, value['subjective_embeddings'])[0]
        combined = objective_score * objective_weight + review_scores * subjective_weight
        # argmax returns the first of tied maxima (same choice as a strict '>'
        # scan) and, unlike a running max seeded at 0, reports the real score
        # when every weighted score is negative -- cosine similarity can be.
        best_review_index = int(combined.argmax())
        weighted_scores.append((key, float(combined[best_review_index]), best_review_index))
    weighted_scores.sort(key=lambda item: item[1], reverse=True)
    return weighted_scores[:int(n)]
def filter_anime(embeddings, genres, themes, rating):
    """Return the entries of ``embeddings`` that satisfy every filter.

    Each filter is a collection of accepted values. An anime passes a filter
    when it shares at least one value with it, when the filter contains
    ``'ALL'``, or when the filter is empty or ``None``. The UI presents the
    filters as optional, so clearing one should not exclude every anime.

    Args:
        embeddings: mapping of anime key -> record with ``'genres'`` and
            ``'themes'`` (iterables of values) and ``'rating'`` (one value).
        genres: accepted genres.
        themes: accepted themes.
        rating: accepted ratings.

    Returns:
        A new dict with only the matching entries, in the original order.
        ``embeddings`` itself is not modified.
    """
    def _accepts(allowed, values):
        # An unset filter, or one containing 'ALL', matches everything.
        return not allowed or 'ALL' in allowed or not allowed.isdisjoint(values)

    genres = set(genres or ())
    themes = set(themes or ())
    rating = set(rating or ())
    return {
        key: anime
        for key, anime in embeddings.items()
        if _accepts(genres, anime['genres'])
        and _accepts(themes, anime['themes'])
        and _accepts(rating, (anime['rating'],))
    }
def get_recommendation(query, number_of_recommendations, genres, themes, rating, objective_weight, subjective_weight):
    """Gradio callback: build the output updates for one recommendation request.

    Filters the module-level ``embeddings`` (loaded in the ``__main__`` block)
    with ``filter_anime``, ranks what is left with ``get_n_weighted_scores``,
    and turns each hit into an (image, description, most relevant review)
    triple of component updates.

    Args:
        query: free-text description of what the user is looking for.
        number_of_recommendations: how many results to show. Cast to ``int``
            in case the slider delivers a float, and capped at the number of
            output rows.
        genres: genre selection forwarded to ``filter_anime``.
        themes: theme selection forwarded to ``filter_anime``.
        rating: rating selection forwarded to ``filter_anime``.
        objective_weight: weight of the objective score; cast to ``float``.
        subjective_weight: weight of the review score; cast to ``float``.

    Returns:
        A flat list of exactly ``3 * 15`` values, one (image, description,
        review) triple per output row built in the UI. Gradio needs one value
        per output component. Unused rows are hidden so that results from an
        earlier, larger request do not stay on screen.
    """
    # Must equal the number of result rows created in the UI (``range(15)``).
    max_results = 15
    count = min(int(number_of_recommendations), max_results)
    filtered_anime = filter_anime(embeddings, genres, themes, rating)
    weighted_scores = get_n_weighted_scores(
        filtered_anime, query, count, float(objective_weight), float(subjective_weight)
    )
    results = []
    for idx, (key, _score, review_index) in enumerate(weighted_scores, start=1):
        data = embeddings[key]
        english = data['english']
        review = data['reviews'][review_index]['text']
        results.append(gr.Image(label=f"{english}", value=data['image'], height=435, width=500, visible=True))
        results.append(gr.Textbox(label=f"Recommendation {idx}: {english}", value=data['description'], max_lines=7, visible=True))
        results.append(gr.Textbox(label=f"Most Relevant User Review", value=review, max_lines=7, visible=True))
    # Pad with one hidden triple per unused row. Base the count on the results
    # actually found, which can be fewer than requested when the filters
    # leave few candidates. Hidden updates (rather than placeholder strings)
    # are also valid for the Image outputs.
    for _ in range(max_results - len(weighted_scores)):
        results.extend([gr.Image(visible=False), gr.Textbox(visible=False), gr.Textbox(visible=False)])
    return results
if __name__ == '__main__':
    # Load the precomputed embeddings and the filter choices. JSON is UTF-8
    # (RFC 8259). Without an explicit encoding, open() would use the platform
    # locale encoding instead.
    with open('./embeddings/data.json', encoding='utf-8') as f:
        data = json.load(f)
    # Module-level global read by get_recommendation().
    embeddings = data['embeddings']
    filters = data['filters']
    with gr.Blocks(theme=gr.themes.Soft(primary_hue='red')) as demo:
        with gr.Row():
            with gr.Column():
                gr.Markdown(
                    '''
                    # Welcome to the Nuanced Recommendation System!
                    ### This system **combines** both objective (synopsis, episode count, themes) and subjective (user reviews) data, in order to recommend the most appropriate anime. Feel free to refine using the **optional** filters below!
                    '''
                )
            with gr.Column():
                # Intentionally empty: keeps the header to the left half.
                pass
        with gr.Row():
            with gr.Column() as input_col:
                query = gr.Textbox(label="What are you looking for?")
                number_of_recommendations = gr.Slider(label= "# of Recommendations", minimum=1, maximum=10, value=3, step=1)
                genres = gr.Dropdown(label='Genres',multiselect=True,choices=filters['genres'], value=['ALL'])
                themes = gr.Dropdown(label='Themes',multiselect=True,choices=filters['themes'], value=['ALL'])
                rating = gr.Dropdown(label='Rating',multiselect=True,choices=filters['rating'], value=['ALL'])
                objective_weight = gr.Slider(label= "Objective Weight", minimum=0, maximum=1, value=.5, step=.1)
                subjective_weight = gr.Slider(label= "Subjective Weight", minimum=0, maximum=1, value=.5, step=.1)
                submit_btn = gr.Button("Submit")
                examples = gr.Examples(
                    examples=[
                        ['A sci-fi anime set in a future where AI and robots have become self-aware', 3, ['Action', 'Sci-Fi', 'Fantasy'], ['ALL'], ['PG-13 - Teens 13 or older'], .8, .2],
                        ['An anime where a group of students form a band, and the story focuses on their personal growth and struggles with adulthood', 5, ['ALL'], ['Music'], ['PG-13 - Teens 13 or older', 'R - 17+ (violence & profanity)'], .3, .7],
                        ['An anime where the main character starts as a villain but slowly redeems themselves', 3, ['Suspense', 'Action'], ['ALL'], ['PG-13 - Teens 13 or older', 'R - 17+ (violence & profanity)'], .2, .8],
                    ],
                    inputs=[query, number_of_recommendations, genres, themes, rating, objective_weight, subjective_weight],
                )
            # 15 hidden result rows, each an (image, description, review)
            # triple. get_recommendation() must return exactly one value per
            # component, in this order.
            outputs = []
            with gr.Column():
                for _ in range(15):
                    with gr.Row():
                        with gr.Column():
                            outputs.append(gr.Image(height=435, width=500, visible=False))
                        with gr.Column():
                            outputs.append(gr.Textbox(max_lines=7, visible=False))
                            outputs.append(gr.Textbox(max_lines=7, visible=False))
        submit_btn.click(
            get_recommendation,
            [query, number_of_recommendations, genres, themes, rating, objective_weight, subjective_weight],
            outputs
        )
    demo.launch()
|