Spaces:
Running
on
T4
Running
on
T4
File size: 4,578 Bytes
70ec0d7 59b8055 70ec0d7 59b8055 70ec0d7 59b8055 70ec0d7 59b8055 70ec0d7 59b8055 70ec0d7 59b8055 70ec0d7 59b8055 15763b2 70ec0d7 59b8055 70ec0d7 59b8055 70ec0d7 59b8055 70ec0d7 59b8055 70ec0d7 59b8055 70ec0d7 59b8055 70ec0d7 59b8055 70ec0d7 59b8055 70ec0d7 59b8055 70ec0d7 ea11adf 70ec0d7 59b8055 70ec0d7 59b8055 70ec0d7 59b8055 70ec0d7 59b8055 70ec0d7 15763b2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import gradio as gr
import nltk
import numpy as np
import pandas as pd
from librosa import load, resample
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
# ---------------------------------------------------------------------------
# Configuration: data files and model identifiers
# ---------------------------------------------------------------------------
filename = "df10k_SP500_2020.csv.zip"
embeddings_filename = "df10k_embeddings_msmarco-distilbert-base-v4.npz"
model_name = "sentence-transformers/msmarco-distilbert-base-v4"
asr_model = "facebook/wav2vec2-xls-r-300m-21-to-en"
max_sequence_length = 512

# ---------------------------------------------------------------------------
# Corpus: one entry per sentence of each filing's 'mdna' column
# ('Management discussion and analysis')
# ---------------------------------------------------------------------------
df = pd.read_csv(filename).drop_duplicates()
print(f"Number of documents: {len(df)}")

nltk.download("punkt")

# Tokenize every document once. Values are passed through str() so that
# non-string cells are tokenized as their string form; the sentence order must
# stay exactly as produced here because the pre-computed embeddings loaded
# below are looked up by sentence index.
_sentences_per_doc = [
    nltk.tokenize.sent_tokenize(str(mdna_text), language="english")
    for mdna_text in df["mdna"]
]
# sentence_count[i] is the number of sentences contributed by document i;
# corpus is the flat list of all sentences in document order.
sentence_count = [len(doc_sentences) for doc_sentences in _sentences_per_doc]
corpus = [
    sentence for doc_sentences in _sentences_per_doc for sentence in doc_sentences
]
print(f"Number of sentences: {len(corpus)}")

# ---------------------------------------------------------------------------
# Pre-computed sentence embeddings for the corpus
# ---------------------------------------------------------------------------
corpus_embeddings = np.load(embeddings_filename)["arr_0"]
print(f"Number of embeddings: {corpus_embeddings.shape[0]}")

# ---------------------------------------------------------------------------
# Models: sentence embedder for queries, speech-to-text for audio queries
# ---------------------------------------------------------------------------
model = SentenceTransformer(model_name)
model.max_seq_length = max_sequence_length

asr = pipeline(
    "automatic-speech-recognition",
    model=asr_model,
    feature_extractor=asr_model,
)
def find_sentences(query, hits):
    """Return the corpus sentences most semantically similar to a query.

    The query is embedded with the module-level ``model`` and compared against
    ``corpus_embeddings`` with ``util.semantic_search``. Each matching sentence
    is then mapped back to the row of ``df`` (the filing) it was extracted
    from, using the per-document sentence counts in ``sentence_count``.

    Args:
        query: Query text to embed.
        hits: Maximum number of matches to return (passed as ``top_k``).

    Returns:
        A ``pandas.DataFrame`` with the columns "Ticker", "Form type",
        "Filing date", "Text" (first 80 characters of the sentence) and
        "Score" (similarity formatted with two decimals), one row per match,
        in the order returned by ``util.semantic_search``.
    """
    columns = ["Ticker", "Form type", "Filing date", "Text", "Score"]

    query_embedding = model.encode(query)
    # semantic_search returns one result list per query; there is one query.
    matches = util.semantic_search(
        query_embedding, corpus_embeddings, top_k=hits
    )[0]

    # doc_ends[i] is the total number of sentences in documents 0..i, i.e. one
    # past the last corpus index that belongs to document i. Computing it once
    # lets every match be resolved with a binary search instead of re-scanning
    # sentence_count from the start.
    doc_ends = np.cumsum(sentence_count)

    rows = []
    for match in matches:
        corpus_id = match["corpus_id"]
        # First document whose end boundary lies strictly beyond corpus_id.
        # side="right" steps over documents that contributed zero sentences.
        doc_idx = int(np.searchsorted(doc_ends, corpus_id, side="right"))
        if doc_idx >= len(doc_ends):
            # Sentence index beyond the known corpus: no source document, so
            # the match is skipped rather than raising.
            continue
        doc = df.iloc[doc_idx]
        rows.append(
            {
                "Ticker": doc["ticker"],
                "Form type": doc["form_type"],
                "Filing date": doc["filing_date"],
                "Text": corpus[corpus_id][:80],
                "Score": "{:.2f}".format(match["score"]),
            }
        )

    # Build the frame once; passing columns keeps the header even with no rows.
    return pd.DataFrame(rows, columns=columns)
def process(input_selection, query, filepath, hits):
    """Resolve the query text (typed or spoken) and search the corpus.

    Args:
        input_selection: "speech" to transcribe the audio file at
            ``filepath``; any other value uses ``query`` unchanged.
        query: Typed query text.
        filepath: Path to the recorded audio file. Only read when
            ``input_selection`` is "speech".
        hits: Number of results requested, forwarded to ``find_sentences``.

    Returns:
        A tuple ``(text, results)`` where ``text`` is the query string that
        was actually searched (the transcription for speech input) and
        ``results`` is the value returned by ``find_sentences``.

    Raises:
        ValueError: If speech input is selected but no audio file was given.
    """
    if input_selection != "speech":
        return query, find_sentences(query, hits)

    if not filepath:
        # The audio input is optional in the UI, so it can be missing; fail
        # with an explicit message instead of an obscure loader error.
        raise ValueError("Speech input selected but no audio was provided.")

    # librosa.load resamples to its own default rate (22050 Hz) unless `sr` is
    # given, which made the previous "!= 16000" check always true and resampled
    # every clip twice. Requesting 16 kHz here does it in a single step.
    target_sampling_rate = 16000
    speech, _ = load(filepath, sr=target_sampling_rate)
    text = asr(speech)["text"]
    return text, find_sentences(text, hits)
# --- Input components -------------------------------------------------------
# Radio button choosing whether the typed text or the recorded audio is used.
buttons = gr.Radio(
    ["text", "speech"],
    type="value",
    value="speech",
    label="Input selection",
)
text_query = gr.Textbox(
    lines=1,
    label="Text input",
    value="The company is under investigation by tax authorities for potential fraud.",
)
mic = gr.Audio(
    source="microphone",
    type="filepath",
    label="Speech input",
    optional=True,
)
slider = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of hits")

# --- Output components ------------------------------------------------------
# The query string that was searched, followed by the table of matches.
speech_query = gr.Textbox(type="text", label="Query string")
results = gr.Dataframe(
    type="pandas",
    headers=["Ticker", "Form type", "Filing date", "Text", "Score"],
    label="Query results",
)

# --- Interface --------------------------------------------------------------
_description = "This Spaces lets you query a text corpus containing 2020 annual filings for all S&P500 companies. You can type a text query in English, or record an audio query in 21 languages. You can find a technical deep dive at https://www.youtube.com/watch?v=YPme-gR0f80"

# (query text, audio file) pairs; every example uses speech input and 3 hits.
_example_queries = [
    (
        "Nos ventes internationales ont significativement augmenté.",
        "sales_16k_fr.wav",
    ),
    (
        "Le prix de l'énergie pourrait avoir un impact négatif dans le futur.",
        "energy_16k_fr.wav",
    ),
    (
        "El precio de la energía podría tener un impacto negativo en el futuro.",
        "energy_24k_es.wav",
    ),
    (
        "Mehrere Steuerbehörden untersuchen unser Unternehmen.",
        "tax_24k_de.wav",
    ),
]
# Each example row follows the order of the interface inputs:
# input selection, text query, audio file, number of hits.
_examples = [
    ["speech", example_text, example_audio, 3]
    for example_text, example_audio in _example_queries
]

iface = gr.Interface(
    fn=process,
    inputs=[buttons, text_query, mic, slider],
    outputs=[speech_query, results],
    examples=_examples,
    description=_description,
    theme="huggingface",
)
iface.launch()
|