Spaces (build status: Build error)
Upload 2 files
app.py CHANGED

@@ -10,7 +10,8 @@ from utils import (
     clean_entities,
     create_dense_embeddings,
     create_sparse_embeddings,
-
+    extract_quarter_year,
+    extract_ticker_spacy,
     format_query,
     get_flan_alpaca_xl_model,
     generate_alpaca_ner_prompt,
@@ -70,9 +71,12 @@ with col1:
     if ner_choice == "Alpaca":
         ner_prompt = generate_alpaca_ner_prompt(query_text)
         entity_text = generate_entities_flan_alpaca_inference_api(ner_prompt)
-        company_ent, quarter_ent, year_ent = format_entities_flan_alpaca(
+        company_ent, quarter_ent, year_ent = format_entities_flan_alpaca(
+            entity_text
+        )
     else:
-        company_ent
+        company_ent = extract_ticker_spacy(query_text, ner_model)
+        quarter_ent, year_ent = extract_quarter_year(query_text)
 
     ticker_index, quarter_index, year_index = clean_entities(
         company_ent, quarter_ent, year_ent
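Note: the second hunk above swaps the single extraction call for two explicit paths. The sketch below is a hypothetical, stand-alone mock of that branch; the stubs only stand in for the real utils.py helpers so the control flow can be run in isolation, and none of their bodies are the app's actual code.

```python
# Hypothetical, self-contained mock of the if/else added to app.py.
# All helper bodies are stubs for illustration only.

def generate_alpaca_ner_prompt(query):
    # Stub: the real helper builds a long instruction-style prompt.
    return f"Extract Company, Quarter and Year from: {query}"

def generate_entities_flan_alpaca_inference_api(prompt):
    # Stub: the real helper calls the Hugging Face Inference API.
    return "Company - AAPL, Quarter - Q1, Year - 2020"

def format_entities_flan_alpaca(entity_text):
    # Stub: split "Company - X, Quarter - Y, Year - Z" into three values.
    parts = dict(p.split(" - ") for p in entity_text.split(", "))
    return parts.get("Company"), parts.get("Quarter"), parts.get("Year")

def extract_ticker_spacy(query, ner_model):
    # Stub: the real helper runs a spaCy NER pipeline over the query.
    return "aapl"

def extract_quarter_year(query):
    # Stub: the real helper uses regexes over the raw query text.
    return "Q1", "2020"

def resolve_entities(ner_choice, query_text, ner_model=None):
    # Mirrors the branch added in app.py.
    if ner_choice == "Alpaca":
        ner_prompt = generate_alpaca_ner_prompt(query_text)
        entity_text = generate_entities_flan_alpaca_inference_api(ner_prompt)
        company_ent, quarter_ent, year_ent = format_entities_flan_alpaca(entity_text)
    else:
        company_ent = extract_ticker_spacy(query_text, ner_model)
        quarter_ent, year_ent = extract_quarter_year(query_text)
    return company_ent, quarter_ent, year_ent

print(resolve_entities("Alpaca", "How did AAPL do in Q1 2020?"))
print(resolve_entities("spaCy", "How did AAPL do in Q1 2020?"))
```

One consequence of the split, visible in the utils.py diff below: on the spaCy path the quarter and year now come from a regex over the raw query rather than from spaCy's DATE entity.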
utils.py CHANGED

@@ -5,6 +5,7 @@ import requests
 import openai
 import pandas as pd
 import spacy
+import spacy_transformers
 import streamlit_scrollable_textbox as stx
 import torch
 from sentence_transformers import SentenceTransformer
@@ -33,13 +34,17 @@ def get_data():
 
 @st.experimental_singleton
 def get_spacy_model():
-    return spacy.load("
+    return spacy.load("en_core_web_trf")
 
 
 @st.experimental_singleton
 def get_flan_alpaca_xl_model():
-    model = AutoModelForSeq2SeqLM.from_pretrained(
-
+    model = AutoModelForSeq2SeqLM.from_pretrained(
+        "/home/user/app/models/flan-alpaca-xl/"
+    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        "/home/user/app/models/flan-alpaca-xl/"
+    )
     return model, tokenizer
 
 
@@ -478,6 +483,7 @@ Answer:?"""
 
 # Entity Extraction
 
+
 def generate_alpaca_ner_prompt(query):
     prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Use the following guidelines to extract the entities representing the Company, Quarter, and Year in the sentence.
 
@@ -515,19 +521,27 @@ Company - Cisco, Quarter - none, Year - 2016
 ### Response:"""
     return prompt
 
+
 def generate_entities_flan_alpaca_inference_api(prompt):
     API_URL = "https://api-inference.huggingface.co/models/declare-lab/flan-alpaca-xl"
     API_TOKEN = st.secrets["hg_key"]
     headers = {"Authorization": f"Bearer {API_TOKEN}"}
     payload = {
         "inputs": prompt,
-        "parameters": {
-
+        "parameters": {
+            "do_sample": True,
+            "temperature": 0.1,
+            "max_length": 80,
+        },
+        "options": {"use_cache": False, "wait_for_model": True},
     }
     try:
         data = json.dumps(payload)
+        # Key not used as headers=headers not passed
         response = requests.request("POST", API_URL, data=data)
-        output = json.loads(response.content.decode("utf-8"))[0][
+        output = json.loads(response.content.decode("utf-8"))[0][
+            "generated_text"
+        ]
     except:
         output = ""
     print(output)
@@ -536,7 +550,7 @@ def generate_entities_flan_alpaca_inference_api(prompt):
 
 def generate_entities_flan_alpaca_checkpoint(model, tokenizer, prompt):
     model_inputs = tokenizer(prompt, return_tensors="pt")
-    input_ids =
+    input_ids = model_inputs["input_ids"]
     generation_output = model.generate(
         input_ids=input_ids,
         temperature=0.1,
@@ -547,9 +561,9 @@ def generate_entities_flan_alpaca_checkpoint(model, tokenizer, prompt):
     return output
 
 
-def format_entities_flan_alpaca(
+def format_entities_flan_alpaca(values):
     """
-    Extracts the text for each entity from the output generated by the
+    Extracts the text for each entity from the output generated by the
     Flan-Alpaca model.
     """
     try:
@@ -560,22 +574,22 @@ def format_entities_flan_alpaca(model_output):
     year = None
     try:
        company = company_string.split(" - ")[1].lower()
-        company = None if company.lower() ==
+        company = None if company.lower() == "none" else company
     except:
         company = None
     try:
         quarter = quarter_string.split(" - ")[1]
-        quarter = None if quarter.lower() ==
+        quarter = None if quarter.lower() == "none" else quarter
 
     except:
         quarter = None
     try:
         year = year_string.split(" - ")[1]
-        year = None if year.lower() ==
+        year = None if year.lower() == "none" else year
 
     except:
         year = None
-
+
     print((company, quarter, year))
     return company, quarter, year
 
@@ -586,34 +600,27 @@ def extract_quarter_year(string):
     if year_match:
         year = year_match.group()
     else:
-
+        year = None
 
     # Extract quarter from string
     quarter_match = re.search(r"Q\d", string)
     if quarter_match:
         quarter = "Q" + quarter_match.group()[1]
     else:
-
+        quarter = None
 
     return quarter, year
 
 
-def
+def extract_ticker_spacy(query, model):
     doc = model(query)
     entities = {ent.label_: ent.text for ent in doc.ents}
+    print(entities.keys())
     if "ORG" in entities.keys():
         company = entities["ORG"].lower()
-        if "DATE" in entities.keys():
-            quarter, year = extract_quarter_year(entities["DATE"])
-            return company, quarter, year
-        else:
-            return company, None, None
     else:
-
-
-        return None, quarter, year
-    else:
-        return None, None, None
+        company = None
+    return company
 
 
 def clean_entities(company, quarter, year):
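For reference, here is a stand-alone sketch of the spaCy fallback pair added above (extract_ticker_spacy plus extract_quarter_year). The four-digit year pattern is my assumption, since the diff only shows the quarter regex r"Q\d", and the en_core_web_trf pipeline must be installed locally (which is why the spacy_transformers import was added).

```python
# Sketch of the spaCy fallback path from the diff above.
# Assumption: the year regex; the diff shows only the quarter pattern r"Q\d".
import re

import spacy  # requires en_core_web_trf (spacy-transformers) to be installed


def extract_quarter_year(string):
    # Extract a four-digit year (assumed pattern) and a "Q<digit>" quarter.
    year_match = re.search(r"\b(?:19|20)\d{2}\b", string)
    year = year_match.group() if year_match else None
    quarter_match = re.search(r"Q\d", string)
    quarter = "Q" + quarter_match.group()[1] if quarter_match else None
    return quarter, year


def extract_ticker_spacy(query, model):
    # Keep only the ORG entity; quarter/year come from the regex helper.
    doc = model(query)
    entities = {ent.label_: ent.text for ent in doc.ents}
    return entities["ORG"].lower() if "ORG" in entities else None


if __name__ == "__main__":
    nlp = spacy.load("en_core_web_trf")
    query = "How did Apple perform in Q2 2021?"
    print(extract_ticker_spacy(query, nlp), extract_quarter_year(query))
```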
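And a minimal, hedged sketch of the Inference API helper: the payload, options, and response parsing mirror the diff, while the token source (an environment variable instead of st.secrets) and the explicit headers= argument are my substitutions; the diff itself, per its own comment, posts without headers, so the stored key goes unused there.

```python
# Sketch of generate_entities_flan_alpaca_inference_api under stated assumptions:
# token from an environment variable, and headers actually passed to the request.
import json
import os

import requests

API_URL = "https://api-inference.huggingface.co/models/declare-lab/flan-alpaca-xl"


def query_flan_alpaca(prompt):
    payload = {
        "inputs": prompt,
        "parameters": {"do_sample": True, "temperature": 0.1, "max_length": 80},
        "options": {"use_cache": False, "wait_for_model": True},
    }
    headers = {"Authorization": f"Bearer {os.environ.get('HF_API_TOKEN', '')}"}
    try:
        response = requests.post(API_URL, headers=headers, data=json.dumps(payload))
        # The API returns a list of generations; take the first one's text.
        return json.loads(response.content.decode("utf-8"))[0]["generated_text"]
    except Exception:
        # Mirror the diff's fallback: an empty string on any failure.
        return ""
```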