Norgan97 commited on
Commit
e838bd7
1 Parent(s): 7be954b
Files changed (2) hide show
  1. app.py +190 -0
  2. data/countries.csv +0 -0
app.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import torch
4
+ from geopy.geocoders import ArcGIS
5
+ import folium
6
+ from streamlit_folium import folium_static
7
+ from transformers import AutoTokenizer, AutoModel
8
+ import numpy as np
9
+ from sklearn.metrics.pairwise import cosine_similarity
10
+
11
+ session_state = st.session_state
12
+ if not hasattr(session_state, 'recommended_countries'):
13
+ session_state.recommended_countries = []
14
+
15
+ st.set_page_config(layout="wide")
16
+
17
+ @st.cache_resource()
18
+ def load_model():
19
+ model = AutoModel.from_pretrained("cointegrated/rubert-tiny2")
20
+ tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
21
+ return model, tokenizer
22
+
23
+ model, tokenizer = load_model()
24
+
25
+ @st.cache_data()
26
+ def load_data():
27
+ df = pd.read_csv('data/countries.csv')
28
+ return df
29
+
30
+ df = load_data()
31
+
32
+ def embed_bert_cls(text, model, tokenizer):
33
+ t = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
34
+ with torch.no_grad():
35
+ model_output = model(**{k: v.to(model.device) for k, v in t.items()})
36
+ embeddings = model_output.last_hidden_state[:, 0, :]
37
+ embeddings = torch.nn.functional.normalize(embeddings)
38
+ return embeddings[0].cpu().numpy()
39
+
40
+ def get_coordinates(country_name):
41
+ geolocator = ArcGIS()
42
+ location = geolocator.geocode(country_name)
43
+ if location:
44
+ return location.latitude, location.longitude
45
+ else:
46
+ return None
47
+
48
+ st.markdown("<h1 style='text-align: center;'>блаблабла</h1>", unsafe_allow_html=True)
49
+
50
+
51
+ st.markdown("<style> input {font-size: 25px !important;}</style>", unsafe_allow_html=True)
52
+
53
+ first_input = st.text_input('Введите предпочтения по климату и типу местности')
54
+ second_input = st.text_input('Введите предпочтения по еде')
55
+ third_input = st.text_input('Введите предпочтения по активностям')
56
+ option = st.selectbox(
57
+ 'Виза',
58
+ ('Да', 'Нет')
59
+ )
60
+ sec_option = st.selectbox(
61
+ 'Местоположение',
62
+ ('Африка','Азия','Европа','Океания','Северная Америка','Южная Америка')
63
+ )
64
+ third_option = st.slider('Выберите значение, характеризующее оценку безопасности страны', 1.0, 3.6, 1.7, 0.1)
65
+
66
+ col3,col4 = st.columns([1,1])
67
+ col5,col6, col7 = st.columns([5,5,5])
68
+
69
+ with col3:
70
+ button_test = st.button('Получить рекомендацию')
71
+
72
+ if button_test and first_input and second_input and third_input :
73
+ filtered_df = df[df['visa'] == option]
74
+ filtered_df = filtered_df[filtered_df['location'] == sec_option]
75
+ filtered_df = filtered_df[filtered_df['peace_index'] <= third_option]
76
+
77
+ decode_first = embed_bert_cls(first_input, model, tokenizer)
78
+ decode_second = embed_bert_cls(second_input, model, tokenizer)
79
+ decode_third = embed_bert_cls(third_input, model, tokenizer)
80
+ try:
81
+ review_embeddings = np.vstack(filtered_df['embeddings_review'].apply(lambda x: np.fromstring(x[1:-1], sep=' ')))
82
+ kitchen_embeddings = np.vstack(filtered_df['embeddings_kitchen'].apply(lambda x: np.fromstring(x[1:-1], sep=' ')))
83
+ activity_embeddings = np.vstack(filtered_df['embeddings_activity'].apply(lambda x: np.fromstring(x[1:-1], sep=' ')))
84
+ similarity_col1 = cosine_similarity(decode_first.reshape(1, -1), review_embeddings)
85
+ similarity_col2 = cosine_similarity(decode_second.reshape(1, -1), kitchen_embeddings)
86
+ similarity_col3 = cosine_similarity(decode_third.reshape(1, -1), activity_embeddings)
87
+ mean_similarity = np.mean([similarity_col1, similarity_col2, similarity_col3], axis=0)
88
+ max_similarity_row = np.argmax(mean_similarity)
89
+ max_similarity_value = np.max(mean_similarity)
90
+
91
+ recommended_country = filtered_df.iloc[max_similarity_row]['country']
92
+ recommended_review = filtered_df.iloc[max_similarity_row]['short_review']
93
+ recommended_flag = filtered_df.iloc[max_similarity_row]['flag']
94
+ recommended_photo = filtered_df.iloc[max_similarity_row]['country_photo']
95
+ similarity_values = [similarity_col1[:, max_similarity_row],
96
+ similarity_col2[:, max_similarity_row],
97
+ similarity_col3[:, max_similarity_row]]
98
+
99
+ session_state.recommended_countries.append(recommended_country)
100
+ with col5:
101
+ st.image(recommended_photo, width=795, use_column_width=False)
102
+ with col6:
103
+ st.image(recommended_flag, width=200, use_column_width=False)
104
+ st.markdown(f"<p style='font-size: 25px;'>Рекомендуемая страна: {recommended_country}</p>", unsafe_allow_html=True)
105
+ st.markdown(f"<p style='font-size: 25px;'> {recommended_review}</p>", unsafe_allow_html=True)
106
+ scale_html = f'<div style="width: 300px; height: 30px;">'
107
+ scale_html += f'<progress value="{max_similarity_value}" max="1" style="width: 100%; height: 100%;"></progress>'
108
+ scale_html += f'<div style="position: relative; top: -22px; text-align: center;">'
109
+ scale_html += f'<span style="position: absolute; left: 0;">0</span>'
110
+ scale_html += f'<span style="position: absolute; right: 0;">1</span>'
111
+ scale_html += f'</div></div>'
112
+ st.markdown(f"<p style='font-size: 25px;'>Оценка близости вашего запроса и страны</p>", unsafe_allow_html=True)
113
+ st.markdown(scale_html, unsafe_allow_html=True)
114
+ with col7:
115
+ # st.write('Местоположение на карте мира')
116
+ coordinates = get_coordinates(recommended_country)
117
+ if coordinates:
118
+ my_map = folium.Map(location=coordinates, zoom_start=5, tiles="Cartodb Positron",max_bounds=True, min_lon=-180, max_lon=180, min_lat=-90, max_lat=90,min_zoom=2,max_zoom=15)
119
+
120
+ folium.Marker(location=coordinates, popup=recommended_country).add_to(my_map)
121
+ folium_static(my_map)
122
+ else:
123
+ st.write(f"Координаты для страны {recommended_country} не найдены.")
124
+ except ValueError as e:
125
+ st.write('Нет такой страны')
126
+ if session_state.recommended_countries:
127
+ with col4:
128
+
129
+ next_button = st.button("Следующая рекомендация")
130
+ if next_button:
131
+ filtered_df = df[df['visa'] == option]
132
+ filtered_df = filtered_df[filtered_df['location'] == sec_option]
133
+ filtered_df = filtered_df[~filtered_df['country'].isin(session_state.recommended_countries)]
134
+ decode_first = embed_bert_cls(first_input, model, tokenizer)
135
+ decode_second = embed_bert_cls(second_input, model, tokenizer)
136
+ decode_third = embed_bert_cls(third_input, model, tokenizer)
137
+
138
+ review_embeddings = np.vstack(filtered_df['embeddings_review'].apply(lambda x: np.fromstring(x[1:-1], sep=' '))) if not filtered_df.empty else None
139
+ kitchen_embeddings = np.vstack(filtered_df['embeddings_kitchen'].apply(lambda x: np.fromstring(x[1:-1], sep=' '))) if not filtered_df.empty else None
140
+ activity_embeddings = np.vstack(filtered_df['embeddings_activity'].apply(lambda x: np.fromstring(x[1:-1], sep=' '))) if not filtered_df.empty else None
141
+
142
+ if review_embeddings is not None and kitchen_embeddings is not None and activity_embeddings is not None:
143
+ similarity_col1 = cosine_similarity(decode_first.reshape(1, -1), review_embeddings)
144
+ similarity_col2 = cosine_similarity(decode_second.reshape(1, -1), kitchen_embeddings)
145
+ similarity_col3 = cosine_similarity(decode_third.reshape(1, -1), activity_embeddings)
146
+ mean_similarity = np.mean([similarity_col1, similarity_col2, similarity_col3], axis=0)
147
+ max_similarity_row = np.argmax(mean_similarity)
148
+ max_similarity_value = np.max(mean_similarity)
149
+
150
+ if max_similarity_value > 0:
151
+
152
+ recommended_country = filtered_df.iloc[max_similarity_row]['country']
153
+ recommended_review = filtered_df.iloc[max_similarity_row]['short_review']
154
+ recommended_flag = filtered_df.iloc[max_similarity_row]['flag']
155
+ recommended_photo = filtered_df.iloc[max_similarity_row]['country_photo']
156
+ similarity_values = [similarity_col1[:, max_similarity_row],
157
+ similarity_col2[:, max_similarity_row],
158
+ similarity_col3[:, max_similarity_row]]
159
+
160
+
161
+ session_state.recommended_countries.append(recommended_country)
162
+ with col5:
163
+ st.image(recommended_photo, width=795, use_column_width=False)
164
+ with col6:
165
+ st.image(recommended_flag, width=200, use_column_width=False)
166
+ st.markdown(f"<p style='font-size: 25px;'>Рекомендуемая страна: {recommended_country}</p>", unsafe_allow_html=True)
167
+ st.markdown(f"<p style='font-size: 25px;'> {recommended_review}</p>", unsafe_allow_html=True)
168
+ scale_html = f'<div style="width: 300px; height: 30px;">'
169
+ scale_html += f'<progress value="{max_similarity_value}" max="1" style="width: 100%; height: 100%;"></progress>'
170
+ scale_html += f'<div style="position: relative; top: -22px; text-align: center;">'
171
+ scale_html += f'<span style="position: absolute; left: 0;">0</span>'
172
+ scale_html += f'<span style="position: absolute; right: 0;">1</span>'
173
+ scale_html += f'</div></div>'
174
+ st.markdown(f"<p style='font-size: 25px;'>Оценка близости вашего запроса и страны</p>", unsafe_allow_html=True)
175
+ st.markdown(scale_html, unsafe_allow_html=True)
176
+ with col7:
177
+ # st.write('Местоположение на карте мира')
178
+ coordinates = get_coordinates(recommended_country)
179
+ if coordinates:
180
+ my_map = folium.Map(location=coordinates, zoom_start=5, tiles="Cartodb Positron",
181
+ max_bounds=True,
182
+ min_lon=-180, max_lon=180, min_lat=-90, max_lat=90, min_zoom=2,
183
+ max_zoom=15)
184
+
185
+ folium.Marker(location=coordinates, popup=recommended_country).add_to(my_map)
186
+ folium_static(my_map)
187
+ else:
188
+ st.write(f"Координаты для страны {recommended_country} не найдены.")
189
+ else:
190
+ st.write("Больше рекомендаций нет.")
data/countries.csv ADDED
The diff for this file is too large to render. See raw diff