ndebuhr's picture
Setup the app code, requirements, and metadata
3c30e6f
raw
history blame contribute delete
No virus
4.36 kB
import apache_beam as beam
import gradio as gr
import huggingface_hub
import pandas as pd
import plotly.graph_objects as go
import spaces
import textwrap
import torch
import us
from apache_beam.options.pipeline_options import PipelineOptions, SetupOptions
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import logging
import os
import requests
# Hugging Face Hub model ID passed to AutoTokenizer / AutoModelForCausalLM
# .from_pretrained in generate_summaries.
MODEL_NAME = "google/gemma-2-2b-it"
# Prompt sent to the model for one state. The single `{}` placeholder is filled
# with str.format in generate_summaries, using that state's alert record
# serialized as indented JSON (any other literal braces would need doubling).
PROMPT_TEMPLATE = """Write a succinct summary of the following weather alerts. Do not comment on missing information - just summarize the information provided/available.
```json
{}
```
Summary (In the state...):
"""
# Shared accumulator of per-state alert records. FetchWeatherAlerts.process
# appends one {"features": [...], "state": ...} dict for each state whose
# weather.gov request returns HTTP 200; generate_summaries later adds a
# "summary" key to each record.
alerts = []
# Beam DoFn that fetches active weather.gov alerts for one US state.
class FetchWeatherAlerts(beam.DoFn):
    """Fetch active weather.gov alerts for a single state.

    Each input element is a state abbreviation (the pipeline feeds
    ``state.abbr`` values from ``us.states.STATES``). On an HTTP 200 response,
    a record ``{"features": [...], "state": state}`` is built. It keeps only
    features whose ``messageType`` is ``"Alert"``, and only their ``event``,
    ``headline`` and ``instruction`` properties. The record is appended to the
    module-level ``alerts`` list, which the rest of the app reads. It is also
    yielded so downstream transforms can consume it.

    Failures are best-effort: a network error or a non-200 status is logged as
    a warning and that state is skipped rather than failing the pipeline.
    """

    # Seconds to wait for weather.gov before giving up on a state. Without a
    # timeout, requests can wait indefinitely on a stalled connection, which
    # would hang the pipeline and therefore the app's startup.
    REQUEST_TIMEOUT = 30

    def process(self, state):
        logging.info(f"Fetching weather alerts for {state} from weather.gov")
        url = f"https://api.weather.gov/alerts/active?area={state}"
        try:
            response = requests.get(
                url,
                headers={
                    "User-Agent": "(Neal DeBuhr, https://huggingface.co./spaces/ndebuhr/streaming-llm-weather-alerts)",
                    "Accept": "application/geo+json",
                },
                timeout=self.REQUEST_TIMEOUT,
            )
        except requests.RequestException as err:
            # Skip this state, as is already done for non-200 responses,
            # instead of aborting every other state's fetch.
            logging.warning(f"Failed to fetch weather alerts for {state}: {err}")
            return
        if response.status_code != 200:
            logging.warning(
                f"weather.gov returned HTTP {response.status_code} for {state}; skipping"
            )
            return
        logging.info(f"Fetched weather alerts for {state} from weather.gov")
        features = response.json()["features"]
        record = {
            "features": [
                {
                    "event": feature["properties"]["event"],
                    "headline": feature["properties"]["headline"],
                    "instruction": feature["properties"]["instruction"],
                }
                for feature in features
                if feature["properties"]["messageType"] == "Alert"
            ],
            "state": state,
        }
        # The app reads results from the global list. Yielding as well makes
        # the record available as this DoFn's output PCollection.
        alerts.append(record)
        yield record
pipeline_options = PipelineOptions()
# Save the main session state so that pickled functions and classes
# defined in __main__ (here, FetchWeatherAlerts) can be unpickled by the runner.
# This must be set before the pipeline below runs.
pipeline_options.view_as(SetupOptions).save_main_session = True
# Create and run the Apache Beam pipeline to fetch weather alerts: one element
# per abbreviation in us.states.STATES, each handled by FetchWeatherAlerts.
# Exiting the `with` block runs the pipeline. Its output PCollection is not
# consumed here; results are read afterwards from the module-level `alerts`
# list, which FetchWeatherAlerts fills as a side effect.
# NOTE(review): collecting results through a global list only works when the
# runner executes the DoFn in this same process (e.g. a local runner) —
# confirm before switching to a distributed runner.
with beam.Pipeline(options=pipeline_options) as p:
    (p
     | "Create States" >> beam.Create([state.abbr for state in us.states.STATES])
     | "Fetch Weather Alerts" >> beam.ParDo(FetchWeatherAlerts())
     )
# Generate alert summaries with transformers on a ZeroGPU-allocated GPU.
# NOTE(review): duration=300 is presumably the GPU time budget in seconds — confirm.
@spaces.GPU(duration=300)
def generate_summaries(alerts):
    """Add an LLM-written ``"summary"`` to each alert record, in place.

    Logs in to the Hugging Face Hub with the ``HUGGINGFACE_TOKEN`` environment
    variable and loads ``MODEL_NAME`` onto CUDA. For each record it formats
    ``PROMPT_TEMPLATE`` with the record serialized as indented JSON, generates
    up to 256 new tokens, and stores the decoded continuation, stripped of
    surrounding whitespace, under ``alert["summary"]``.

    Args:
        alerts: list of alert-record dicts (as built by FetchWeatherAlerts);
            each one is mutated to gain a ``"summary"`` key.

    Returns:
        The same ``alerts`` list object.

    Raises:
        KeyError: if the ``HUGGINGFACE_TOKEN`` environment variable is unset.
    """
    huggingface_hub.login(token=os.environ["HUGGINGFACE_TOKEN"])
    device = torch.device("cuda")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
    for alert in alerts:
        prompt = PROMPT_TEMPLATE.format(json.dumps(alert, indent=2))
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        # For a causal LM, generate() returns the prompt tokens followed by the
        # continuation, so remember where the prompt ends.
        prompt_length = inputs["input_ids"].shape[-1]
        with torch.no_grad():
            outputs = model.generate(
                **inputs, max_new_tokens=256, pad_token_id=tokenizer.eos_token_id
            )
        # Decode only the newly generated tokens. The previous approach decoded
        # the whole sequence and str.replace()d the prompt away. That leaves
        # the entire prompt in the summary whenever the decoded text is not a
        # byte-exact copy of the prompt (tokenizer whitespace/cleanup changes).
        alert["summary"] = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()
    return alerts
# Summarize every fetched state's alerts with the LLM (runs once, when the
# script starts).
alerts = generate_summaries(alerts)
# One row per state: its abbreviation and the generated summary. Passing
# explicit columns keeps "state" and "summary" present even when no alerts
# were fetched. Otherwise get_map() would fail with a KeyError on the
# missing "summary" column instead of rendering an empty map.
df = pd.DataFrame(
    [{"state": alert["state"], "summary": alert["summary"]} for alert in alerts],
    columns=["state", "summary"],
)
def wrap_text(text, width=50):
    """Wrap *text* to lines of at most *width* characters, joined by ``<br>``.

    Plotly hover labels treat ``<br>`` as a line break. Each line of the input
    is wrapped on its own, so existing line breaks (e.g. bullet lists in an
    LLM summary) survive; ``textwrap.wrap`` on the whole text would collapse
    them into spaces. Blank input lines are kept as empty lines.
    """
    wrapped_lines = []
    for line in text.splitlines():
        # textwrap.wrap returns [] for a blank line; keep it as an empty line.
        wrapped_lines.extend(textwrap.wrap(line, width=width) or [""])
    return "<br>".join(wrapped_lines)


def get_map():
    """Build a US choropleth whose hover text is each state's alert summary.

    Reads the module-level ``df`` (columns ``state`` and ``summary``) without
    modifying it. Every state gets the same z value and a single-color
    ``lightgrey`` colorscale, so the map only serves to carry the hover text.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    # Computed locally rather than stored as a new df column, so repeated or
    # concurrent calls don't mutate shared global state.
    hover_text = df["summary"].apply(wrap_text)
    fig = go.Figure(
        go.Choropleth(
            locations=df["state"],
            z=[1] * len(df),
            locationmode="USA-states",
            colorscale=[
                [0, "lightgrey"],
                [1, "lightgrey"],
            ],  # Single color for all states
            showscale=False,
            text=hover_text,
            hoverinfo="text",
            hovertemplate="%{text}<extra></extra>",
        )
    )
    fig.update_layout(title_text="Streaming LLM Weather Alerts", geo_scope="usa")
    return fig
# Expose the map as a Gradio app: no input widgets, a single Plot output
# produced by calling get_map().
iface = gr.Interface(get_map, inputs=None, outputs=gr.Plot())
# Start serving the app.
iface.launch()