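# Course-search demo: describe a course by text or voice; the query is embedded
# with a SentenceTransformer and matched against precomputed course-content
# embeddings by cosine similarity, with an optional department filter.
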
from sentence_transformers import SentenceTransformer
import pickle
import numpy as np
import torch
import gradio as gr
import os
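
# Install Whisper from GitHub at runtime, then load its "base" speech-to-text model.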
os.system("pip install git+https://github.com/openai/whisper.git")
import whisper
infer_model = whisper.load_model("base")


def infer(audio):
    """Transcribe an audio file to text with Whisper."""
    result = infer_model.transcribe(audio)
    return result["text"]
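

# Sentence-embedding model used to encode the text/voice query.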
model = SentenceTransformer("paraphrase-MiniLM-L6-v2")

# Precomputed map: {(department, course title): course-content embedding}.
with open("dep_course_title_to_content_embed.pickle", "rb") as handle:
    loaded_map = pickle.load(handle)

dep_name_course_name = list(loaded_map.keys())
deps = list(set([dep for (dep, _course) in dep_name_course_name]))

# Group course titles and their content embeddings by department.
dep_to_course_name = {}
dep_to_course_embedding = {}
for dep in deps:
    dep_to_course_name[dep] = []
    dep_to_course_embedding[dep] = []

for (dep_name, course_name), embedding in loaded_map.items():
    dep_to_course_name[dep_name].append(course_name)
    dep_to_course_embedding[dep_name].append(np.array(embedding, dtype=np.float32))
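
# Cosine similarity between the query embedding and each course-content embedding.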
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)


def give_best_match(query, audio, Department):
    """Rank courses in the selected departments by cosine similarity to the query."""
    if not Department:
        Department = deps  # no department selected: search all of them
    course_titles = []
    course_content_embeddings = []
    for dep in Department:
        course_titles += dep_to_course_name[dep]
        course_content_embeddings += dep_to_course_embedding[dep]
    course_content_embeddings = np.stack(course_content_embeddings)
    # If audio was recorded, transcribe it and use it as the query.
    if audio:
        query = infer(audio)
    embed = model.encode(query)
    result = cos(torch.from_numpy(course_content_embeddings), torch.from_numpy(embed))
    indices = reversed(np.argsort(result))
    predictions = {course_titles[i]: float(result[i]) for i in indices}
    return query, predictions
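

# Gradio UI: type a description and/or record it by voice; optionally restrict
# the search to specific departments.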
demo = gr.Interface(
    fn=give_best_match,
    inputs=[
        gr.Textbox(
            label="Describe the course",
            lines=5,
            placeholder="Type anything related to the course(s)\n\nTitle, topics/subtopics, reference books, questions asked in exams, or some random fun stuff.",
        ),
        gr.Audio(
            source="microphone",
            type="filepath",
            label="Don't want to type? Try describing it with your sweet voice!!",
            interactive=True,
        ),
        gr.CheckboxGroup(deps, label="(Optional) Departments"),
    ],
    outputs=[
        gr.Textbox(
            label="Query",
            lines=2,
        ),
        gr.Label(label="Most Relevant Courses", num_top_classes=5),
    ],
)

demo.launch()