dejanseo commited on
Commit
eafe7c3
·
verified ·
1 Parent(s): fffb0cd

Upload 2 files

Browse files
Files changed (2) hide show
  1. app/app.py +140 -0
  2. app/requirements.txt +6 -0
app/app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import AlbertTokenizer, AlbertForSequenceClassification
4
+ import plotly.graph_objects as go
5
+
6
+ # URL of the logo
7
+ logo_url = "https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png"
8
+
9
+ # Display the logo at the top using st.logo
10
+ st.logo(logo_url, link="https://dejan.ai")
11
+
12
+ # Streamlit app title and description
13
+ st.title("Search Query Form Classifier")
14
+ st.write(
15
+ "Ambiguous search queries are candidates for query expansion. Our model identifies such queries with an 80 percent accuracy and is deployed in a batch processing pipeline directly connected with Google Search Console API. In this demo you can test the model capability by testing individual queries."
16
+ )
17
+ st.write("Enter a query to check if it's well-formed:")
18
+
19
+ # Load the model and tokenizer from the Hugging Face Model Hub
20
+ model_name = 'dejanseo/Query-Quality-Classifier'
21
+ tokenizer = AlbertTokenizer.from_pretrained(model_name)
22
+ model = AlbertForSequenceClassification.from_pretrained(model_name)
23
+
24
+ # Set the model to evaluation mode
25
+ model.eval()
26
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
27
+ model.to(device)
28
+
29
+ # Create tabs for single and bulk queries
30
+ tab1, tab2 = st.tabs(["Single Query", "Bulk Query"])
31
+
32
+ with tab1:
33
+ user_input = st.text_input("Query:", "where can I book cheap flights to london")
34
+ #st.write("Developed by [Dejan AI](https://dejan.ai/blog/search-query-quality-classifier/)")
35
+
36
+ def classify_query(query):
37
+ # Tokenize input
38
+ inputs = tokenizer.encode_plus(
39
+ query,
40
+ add_special_tokens=True,
41
+ max_length=32,
42
+ padding='max_length',
43
+ truncation=True,
44
+ return_attention_mask=True,
45
+ return_tensors='pt'
46
+ )
47
+
48
+ input_ids = inputs['input_ids'].to(device)
49
+ attention_mask = inputs['attention_mask'].to(device)
50
+
51
+ # Perform inference
52
+ with torch.no_grad():
53
+ outputs = model(input_ids, attention_mask=attention_mask)
54
+ logits = outputs.logits
55
+ softmax_scores = torch.softmax(logits, dim=1).cpu().numpy()[0]
56
+ confidence = softmax_scores[1] * 100 # Confidence for well-formed class
57
+
58
+ return confidence
59
+
60
+ # Function to determine color based on confidence
61
+ def get_color(confidence):
62
+ if confidence < 50:
63
+ return 'rgba(255, 51, 0, 0.8)' # Red
64
+ else:
65
+ return 'rgba(57, 172, 57, 0.8)' # Green
66
+
67
+ # Check and display classification for single query
68
+ if user_input:
69
+ confidence = classify_query(user_input)
70
+
71
+ # Plotly grey placeholder bar with dynamic color fill
72
+ fig = go.Figure()
73
+
74
+ # Placeholder grey bar
75
+ fig.add_trace(go.Bar(
76
+ x=[100],
77
+ y=['Well-formedness Factor'],
78
+ orientation='h',
79
+ marker=dict(
80
+ color='lightgrey'
81
+ ),
82
+ width=0.8
83
+ ))
84
+
85
+ # Colored bar based on confidence
86
+ fig.add_trace(go.Bar(
87
+ x=[confidence],
88
+ y=['Well-formedness Factor'],
89
+ orientation='h',
90
+ marker=dict(
91
+ color=get_color(confidence)
92
+ ),
93
+ width=0.8
94
+ ))
95
+
96
+ fig.update_layout(
97
+ xaxis=dict(range=[0, 100], title='Well-formedness Factor'),
98
+ yaxis=dict(showticklabels=False),
99
+ width=600,
100
+ height=250, # Increase height for better visibility
101
+ title_text='Well-formedness Factor',
102
+ plot_bgcolor='rgba(0,0,0,0)',
103
+ showlegend=False
104
+ )
105
+
106
+ st.plotly_chart(fig)
107
+
108
+ if confidence >= 50:
109
+ st.success(f"Query Score: {confidence:.2f}% Most likely doesn't require query expansion.")
110
+ st.subheader(f":sparkles: What's next?", divider="gray")
111
+ st.write("Connect with Google Search Console, Semrush, Ahrefs or any other search query source API and detect all queries which could benefit from expansion.")
112
+ st.write("[Engage our team](https://dejan.ai/call/) if you'd like us to do this for you.")
113
+ else:
114
+ st.error(f"The query is likely not well-formed with a score of {100 - confidence:.2f}% and most likely requires query expansion.")
115
+ st.subheader(f":sparkles: What's next?", divider="gray")
116
+ st.write("Connect with Google Search Console, Semrush, Ahrefs or any other search query source API and detect all queries which could benefit from expansion.")
117
+ st.write("[Engage our team](https://dejan.ai/call/) if you'd like us to do this for you.")
118
+
119
+ with tab2:
120
+ st.write("Paste multiple queries line-separated (no headers or extra data):")
121
+ bulk_input = st.text_area("Bulk Queries:", height=200)
122
+
123
+ if bulk_input:
124
+ bulk_queries = bulk_input.splitlines()
125
+ st.write("Processing queries...")
126
+
127
+ # Classify each query in bulk input
128
+ results = [(query, classify_query(query)) for query in bulk_queries]
129
+
130
+ # Display results in a table
131
+ for query, confidence in results:
132
+ st.write(f"Query: {query} - Score: {confidence:.2f}%")
133
+ if confidence >= 50:
134
+ st.success("Well-formed")
135
+ else:
136
+ st.error("Not well-formed")
137
+
138
+ st.subheader(f":sparkles: What's next?", divider="gray")
139
+ st.write("Connect with Google Search Console, Semrush, Ahrefs or any other search query source API and detect all queries which could benefit from expansion.")
140
+ st.write("[Engage our team](https://dejan.ai/call/) if you'd like us to do this for you.")
app/requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ transformers
4
+ datasets
5
+ plotly
6
+ sentencepiece