# Emotion-analysis Streamlit app: splits user text into sentences, sends them
# to a Vertex AI custom model endpoint, and visualizes the predicted emotions.
from typing import Dict, List, Union | |
from google.cloud import aiplatform | |
from google.protobuf import json_format | |
from google.protobuf.struct_pb2 import Value | |
import os | |
import re | |
import pandas as pd | |
import plotly.express as px | |
import plotly.graph_objects as go | |
import streamlit as st | |
import nltk | |
import json | |
import tempfile | |
# process of getting credentials
def get_credentials() -> str:
    """Materialize service-account JSON from the JSONSTR env var into a file.

    Returns:
        Path to a temporary ``.json`` file holding the credentials. The file
        is created with ``delete=False`` on purpose: Google client libraries
        read GOOGLE_APPLICATION_CREDENTIALS lazily, so the file must outlive
        this function.

    Raises:
        ValueError: if the JSONSTR environment variable is not set.
    """
    creds_json_str = os.getenv("JSONSTR")  # JSON credentials stored as a string
    if creds_json_str is None:
        # Bug fix: the old message blamed GOOGLE_APPLICATION_CREDENTIALS_JSON,
        # but the variable actually read above is JSONSTR.
        raise ValueError("JSONSTR not found in environment")
    # Write the JSON to a temp file so the SDK can consume a regular file path.
    with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as temp:
        temp.write(creds_json_str)
        temp_filename = temp.name
    return temp_filename
# pass | |
os.environ["GOOGLE_APPLICATION_CREDENTIALS"]= get_credentials() | |
max_seq_length = 2048 | |
dtype = None | |
load_in_4bit = True | |
# Check if 'punkt' is already downloaded, otherwise download it | |
try: | |
nltk.data.find('tokenizers/punkt') | |
except LookupError: | |
nltk.download('punkt') | |
text_split_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle') | |
def predict_custom_trained_model_sample(
    project: str,
    endpoint_id: str,
    instances: Union[Dict, List[Dict]],
    location: str = "us-east4",
    api_endpoint: str = "us-east4-aiplatform.googleapis.com",
) -> List[str]:
    """Call a Vertex AI custom-model endpoint and extract emotion labels.

    Args:
        project: GCP project id.
        endpoint_id: Vertex AI endpoint id.
        instances: A single instance dict or a list of instance dicts in the
            JSON format the deployed model expects.
        location: Endpoint region.
        api_endpoint: Regional API host; must match ``location``.

    Returns:
        Every whitespace-separated token from the model's string predictions
        that is a known emotion label (a value of the module-level
        ``d_emotion`` mapping), in prediction order.
    """
    client_options = {"api_endpoint": api_endpoint}
    client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)
    # Normalize to a list, then convert each dict to a protobuf Value.
    instances = instances if isinstance(instances, list) else [instances]
    instances = [
        json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
    ]
    parameters = json_format.ParseDict({}, Value())
    endpoint = client.endpoint_path(
        project=project, location=location, endpoint=endpoint_id
    )
    response = client.predict(
        endpoint=endpoint, instances=instances, parameters=parameters
    )
    predictions_list = []
    for prediction in response.predictions:
        if isinstance(prediction, str):
            # Strip prompt-template scaffolding the model may echo back.
            clean_prediction = re.sub(r'(\n|Origin|###|Optimization|Response:)', '', prediction)
            predictions_list.extend(clean_prediction.split())
        else:
            print(" prediction (unknown type, skipping):", prediction)
    # Build the membership set once: O(1) lookups instead of scanning
    # d_emotion.values() for every token.
    valid_emotions = set(d_emotion.values())
    return [emotion for emotion in predictions_list if emotion in valid_emotions]
# Index -> label mapping for the 28 GoEmotions classes the model predicts.
_EMOTION_LABELS = (
    'admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring',
    'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval',
    'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief',
    'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization',
    'relief', 'remorse', 'sadness', 'surprise', 'neutral',
)
d_emotion = dict(enumerate(_EMOTION_LABELS))
# --- Streamlit UI: input form ---
st.write("Write or paste any number of document texts to analyse the emotion percentage with your document")
user_input = st.text_area('Enter Text to Analyze')
button = st.button("Analyze")
# Run the analysis only when there is text AND the button was clicked.
if user_input and button:
    # Alpaca-style prompt template; {} is filled with one sentence at a time.
    alpaca_prompt = """Below is a conversation between a human and an AI agent. write a response based on the input.
### Instruction:
predict the emotion word or words
### Input:
{}
### Response:
"""
    instances = []
    # Split the document into sentences; each sentence becomes one instance.
    input_array = text_split_tokenizer.tokenize(user_input)
    for sentence in input_array:
        formatted_input = alpaca_prompt.format(sentence.strip())
        instance = {
            "inputs": formatted_input,
            "parameters": {
                "max_new_tokens": 4,     # labels are short — a few tokens suffice
                "temperature": 0.00001,  # near-greedy decoding for determinism
                "top_p": 0.9,
                "top_k": 10
            }
        }
        instances.append(instance)
    # GCP project/endpoint/location are supplied via environment variables.
    predictions = predict_custom_trained_model_sample(
        project=os.environ["project"],
        endpoint_id=os.environ["endpoint_id"],
        location=os.environ["location"],
        instances=instances
    )
    # Relative frequency of each predicted emotion, expressed as a percentage.
    emotion_counts = pd.Series(predictions).value_counts(normalize=True).reset_index()
    emotion_counts.columns = ['Emotion', 'Percentage']
    emotion_counts['Percentage'] *= 100 # Convert to percentage
    fig_pie = px.pie(emotion_counts, values='Percentage', names='Emotion', title='Percentage of Emotions in Given Text')
    fig_pie.update_traces(textposition='inside', textinfo='percent+label')
def get_emotion_chart(predictions): | |
emotion_counts = pd.Series(predictions).value_counts().reset_index() | |
emotion_counts.columns = ['Emotion', 'Count'] | |
fig_bar = go.Figure() | |
fig_bar.add_trace(go.Bar( | |
x=emotion_counts['Emotion'], | |
y=emotion_counts['Count'], | |
marker_color='indianred' | |
)) | |
fig_bar.update_layout(title='Count of Each Emotion in Given Text', xaxis_title='Emotion', yaxis_title='Count') | |
return fig_bar | |
fig_bar = get_emotion_chart(predictions) | |
def get_emotion_heatmap(predictions): | |
# Create a matrix for heatmap | |
# Count occurrences of each emotion | |
emotion_counts = pd.Series(predictions).value_counts().reset_index() | |
emotion_counts.columns = ['Emotion', 'Count'] | |
heatmap_matrix = pd.DataFrame(0, index=d_emotion.values(), columns=d_emotion.values()) | |
for index, row in emotion_counts.iterrows(): | |
heatmap_matrix.at[row['Emotion'], row['Emotion']] = row['Count'] | |
fig = go.Figure(data=go.Heatmap( | |
z=heatmap_matrix.values, | |
x=heatmap_matrix.columns.tolist(), | |
y=heatmap_matrix.index.tolist(), | |
text=heatmap_matrix.values, | |
hovertemplate="Count: %{text}", | |
colorscale='Viridis' | |
)) | |
fig.update_layout(title='Emotion Heatmap', xaxis_title='Predicted Emotion', yaxis_title='Predicted Emotion') | |
return fig | |
    fig_dist = get_emotion_heatmap(predictions)
    # Render the three figures in separate tabs: pie (%), bar (counts), heatmap.
    tab1, tab2, tab3 = st.tabs(["Emotion Analysis", "Emotion Counts Distribution", "Heatmap"])
    with tab1:
        st.plotly_chart(fig_pie)
    with tab2:
        st.plotly_chart(fig_bar)
    with tab3:
        st.plotly_chart(fig_dist)