File size: 6,381 Bytes
ac6138f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ea3806
ac6138f
 
 
 
 
0ea3806
ac6138f
 
 
 
 
 
 
8e68708
ac6138f
afff708
8e68708
 
 
 
afff708
 
 
 
 
 
 
 
ac6138f
 
 
 
8e68708
 
 
ac6138f
 
 
 
afff708
 
 
ac6138f
 
 
 
 
 
 
 
 
500010a
ac6138f
500010a
 
afff708
ac6138f
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import json

import gradio as gr

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Sentence-embedding model, instantiated once when this module is loaded and
# shared by get_n_weighted_scores(), which uses it to encode the user's query.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')


def get_n_weighted_scores(embeddings, query, n, objective_weight, subjective_weight):
    """Rank entries by a weighted blend of objective and subjective similarity.

    The query text is encoded with the module-level ``model`` and compared,
    using cosine similarity, against each entry's objective embedding and
    against every one of its subjective embeddings. For each entry, the
    subjective embedding that yields the highest blended score is kept.

    Args:
        embeddings: Mapping of key -> dict holding ``'objective_embedding'``
            and ``'subjective_embeddings'``, both in a form accepted by
            ``cosine_similarity``. The objective similarity must reduce to a
            single value (``.item()`` is called on it).
        query: Query text to encode.
        n: Maximum number of results to return.
        objective_weight: Multiplier applied to the objective similarity.
        subjective_weight: Multiplier applied to each subjective similarity.

    Returns:
        A list of ``(key, best_weighted_score, best_subjective_index)``
        tuples, sorted by score in descending order and truncated to ``n``.
    """
    query_vector = [model.encode(query)]

    ranked = []
    for key, entry in embeddings.items():
        objective_score = cosine_similarity(query_vector, entry['objective_embedding']).item()
        subjective_scores = cosine_similarity(query_vector, entry['subjective_embeddings'])[0].tolist()

        # Cosine similarity can be negative, so the best score must come from
        # the actual candidates instead of a running maximum seeded with 0
        # (which would report score 0 / index 0 for all-negative entries).
        # max() returns the first maximal element, so ties keep the lowest index.
        best_index, best_score = max(
            (
                (idx, (objective_score * objective_weight) + (score * subjective_weight))
                for idx, score in enumerate(subjective_scores)
            ),
            key=lambda candidate: candidate[1],
        )
        ranked.append((key, best_score, best_index))

    return sorted(ranked, key=lambda item: item[1], reverse=True)[:n]

def filter_anime(embeddings, genres, themes, rating):
    """Return a shallow copy of ``embeddings`` without entries that fail a filter.

    An entry is kept only when it passes all three filters (genres, themes,
    rating). A filter passes when its selection contains ``'ALL'`` or shares
    at least one value with the entry's corresponding field. The mapping
    passed in is not modified.
    """
    selections = (set(genres), set(themes), set(rating))

    def passes(entry):
        # All three fields are read before any comparison, so an entry that
        # lacks one of them raises KeyError regardless of the selections.
        fields = (set(entry['genres']), set(entry['themes']), {entry['rating']})
        return all(
            'ALL' in chosen or not chosen.isdisjoint(present)
            for chosen, present in zip(selections, fields)
        )

    survivors = embeddings.copy()
    rejected_keys = [key for key, entry in embeddings.items() if not passes(entry)]
    for key in rejected_keys:
        del survivors[key]
    return survivors

def get_recommendation(query, number_of_recommendations, genres, themes, rating, objective_weight, subjective_weight):
    """Build the Gradio output updates for a recommendation request.

    Filters the module-level ``embeddings`` by genre/theme/rating, ranks the
    remaining entries against ``query`` and returns one
    (image, description, most relevant review) triple of component updates
    per recommendation. Result rows that are not needed are hidden.

    Args:
        query: Free-text description of what the user is looking for.
        number_of_recommendations: How many recommendations were requested.
        genres: Selected genre filter values (may contain ``'ALL'``).
        themes: Selected theme filter values (may contain ``'ALL'``).
        rating: Selected rating filter values (may contain ``'ALL'``).
        objective_weight: Weight of the objective similarity score.
        subjective_weight: Weight of the subjective similarity score.

    Returns:
        A list of exactly ``15 * 3`` component updates ordered as
        image, description textbox, review textbox for each result row.
    """
    # Number of (image, description, review) rows created in the UI layout;
    # the returned list must supply one update for each of their components.
    max_rows = 15

    filtered_anime = filter_anime(embeddings, genres, themes, rating)
    weighted_scores = get_n_weighted_scores(
        filtered_anime,
        query,
        int(number_of_recommendations),
        float(objective_weight),
        float(subjective_weight),
    )

    results = []
    for idx, (key, _score, review_index) in enumerate(weighted_scores[:max_rows], start=1):
        data = embeddings[key]
        english = data['english']
        description = data['description']
        review = data['reviews'][review_index]['text']
        image = data['image']

        results.append(gr.Image(label=f"{english}", value=image, height=435, width=500, visible=True))
        results.append(gr.Textbox(label=f"Recommendation {idx}: {english}", value=description, max_lines=7, visible=True))
        results.append(gr.Textbox(label="Most Relevant User Review", value=review, max_lines=7, visible=True))

    # Pad by the number of rows actually produced (filtering can leave fewer
    # matches than requested). Unused rows are cleared and hidden so results
    # from a previous query do not linger, and no placeholder string is sent
    # to an Image component.
    shown_rows = len(results) // 3
    for _ in range(max_rows - shown_rows):
        results.append(gr.Image(value=None, visible=False))
        results.append(gr.Textbox(value=None, visible=False))
        results.append(gr.Textbox(value=None, visible=False))

    return results
    
if __name__ == '__main__':

    # Load the precomputed embeddings and the available filter choices.
    # The encoding is explicit so that non-ASCII text in the JSON is read
    # the same way on every platform.
    with open('./embeddings/data.json', encoding='utf-8') as f:
        data = json.load(f)
        embeddings = data['embeddings']
        filters = data['filters']

    with gr.Blocks(theme=gr.themes.Soft(primary_hue='red')) as demo:
        # Header row: introduction on the left, empty column on the right.
        with gr.Row():
            with gr.Column():
                gr.Markdown(
                    '''
                    # Welcome to the Nuanced Recommendation System!
                    ### This system **combines** both objective (synopsis, episode count, themes) and subjective (user reviews) data, in order to recommend the most appropriate anime. Feel free to refine using the **optional** filters below! 
                    '''
                )
            with gr.Column():
                pass
            

        with gr.Row():
            # Left column: query, filters, weights and example inputs.
            with gr.Column() as input_col:
                query = gr.Textbox(label="What are you looking for?")
                number_of_recommendations = gr.Slider(label= "# of Recommendations", minimum=1, maximum=10, value=3, step=1)
                genres = gr.Dropdown(label='Genres',multiselect=True,choices=filters['genres'], value=['ALL'])
                themes = gr.Dropdown(label='Themes',multiselect=True,choices=filters['themes'], value=['ALL'])
                rating = gr.Dropdown(label='Rating',multiselect=True,choices=filters['rating'], value=['ALL'])
                objective_weight = gr.Slider(label= "Objective Weight", minimum=0, maximum=1, value=.5, step=.1)
                subjective_weight = gr.Slider(label= "Subjective Weight", minimum=0, maximum=1, value=.5, step=.1)
                submit_btn = gr.Button("Submit")

                examples = gr.Examples(
                    examples=[
                        ['A sci-fi anime set in a future where AI and robots have become self-aware', 3, ['Action', 'Sci-Fi', 'Fantasy'], ['ALL'], ['PG-13 - Teens 13 or older'], .8, .2],
                        ['An anime where a group of students form a band, and the story focuses on their personal growth and struggles with adulthood', 5, ['ALL'], ['Music'], ['PG-13 - Teens 13 or older', 'R - 17+ (violence & profanity)'], .3, .7],
                        ['An anime where the main character starts as a villain but slowly redeems themselves', 3, ['Suspense', 'Action'], ['ALL'], ['PG-13 - Teens 13 or older', 'R - 17+ (violence & profanity)'], .2, .8],
                    ],
                    inputs=[query, number_of_recommendations, genres, themes, rating, objective_weight, subjective_weight],
                )

            # Right column: 15 initially hidden result rows. Each row adds an
            # image, a description textbox and a review textbox to `outputs`,
            # in that order, which is the order get_recommendation returns.
            outputs = []
            with gr.Column():
                for i in range(15):
                    with gr.Row():
                        with gr.Column():
                            outputs.append(gr.Image(height=435, width=500, visible=False))
                        with gr.Column():
                            outputs.append(gr.Textbox(max_lines=7, visible=False))
                            outputs.append(gr.Textbox(max_lines=7, visible=False))
                            

        submit_btn.click(
            get_recommendation,
            [query, number_of_recommendations, genres, themes, rating, objective_weight, subjective_weight],
            outputs
        )

    demo.launch()