File size: 4,356 Bytes
3c30e6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import apache_beam as beam
import gradio as gr
import huggingface_hub
import pandas as pd
import plotly.graph_objects as go
import spaces
import textwrap
import torch
import us

from apache_beam.options.pipeline_options import PipelineOptions, SetupOptions
from transformers import AutoTokenizer, AutoModelForCausalLM

import json
import logging
import os
import requests

# Hugging Face model id used to summarize each state's alerts.
MODEL_NAME = "google/gemma-2-2b-it"
# Prompt template for the summarizer; the {} placeholder is filled with a
# JSON dump of one state's alert dict (see generate_summaries).
PROMPT_TEMPLATE = """Write a succinct summary of the following weather alerts. Do not comment on missing information - just summarize the information provided/available.

```json
{}
```

Summary (In the state...):
"""

# Initialize an empty list to store weather alerts
# (populated as a side effect by FetchWeatherAlerts.process).
alerts = []


# Define a transform for fetching weather alerts
class FetchWeatherAlerts(beam.DoFn):
    """Beam DoFn that fetches active NWS alerts for a single US state.

    Results are accumulated in the module-level ``alerts`` list rather than
    emitted downstream — NOTE(review): this only works with an in-process
    runner (e.g. DirectRunner); a distributed runner would not share the list.
    """

    def process(self, state):
        """Fetch active alerts for *state* (two-letter abbreviation).

        Appends a dict with keys "state" and "features" (event / headline /
        instruction for each message of type "Alert") to ``alerts``.
        """
        logging.info(f"Fetching weather alerts for {state} from weather.gov")
        url = f"https://api.weather.gov/alerts/active?area={state}"
        response = requests.get(
            url,
            headers={
                "User-Agent": "(Neal DeBuhr, https://huggingface.co./spaces/ndebuhr/streaming-llm-weather-alerts)",
                "Accept": "application/geo+json",
            },
            # Without a timeout an unresponsive API would hang the whole
            # pipeline indefinitely.
            timeout=30,
        )
        if response.status_code == 200:
            logging.info(f"Fetched weather alerts for {state} from weather.gov")
            features = response.json()["features"]
            alerts.append(
                {
                    "features": [
                        {
                            "event": feature["properties"]["event"],
                            "headline": feature["properties"]["headline"],
                            "instruction": feature["properties"]["instruction"],
                        }
                        for feature in features
                        if feature["properties"]["messageType"] == "Alert"
                    ],
                    "state": state,
                }
            )
        else:
            # Previously a non-200 response was dropped silently; keep the
            # best-effort behavior (skip the state) but surface the failure.
            logging.warning(
                "Failed to fetch weather alerts for %s: HTTP %s",
                state,
                response.status_code,
            )


# Configure Beam pipeline options. Saving the main session state lets
# pickled functions and classes defined in __main__ be unpickled on workers.
pipeline_options = PipelineOptions()
pipeline_options.view_as(SetupOptions).save_main_session = True

# Two-letter abbreviations for all US states, one pipeline element each.
state_abbreviations = [s.abbr for s in us.states.STATES]

# Run the Beam pipeline; FetchWeatherAlerts populates the module-level
# `alerts` list as a side effect.
with beam.Pipeline(options=pipeline_options) as pipeline:
    states = pipeline | "Create States" >> beam.Create(state_abbreviations)
    states | "Fetch Weather Alerts" >> beam.ParDo(FetchWeatherAlerts())


# Define a function to generate alert summaries using transformers and ZeroGPU
@spaces.GPU(duration=300)
def generate_summaries(alerts):
    """Add an LLM-generated "summary" key to each alert dict, in place.

    Args:
        alerts: list of dicts as produced by FetchWeatherAlerts (keys
            "state" and "features").

    Returns:
        The same list, with a "summary" string added to each element.
    """
    huggingface_hub.login(token=os.environ["HUGGINGFACE_TOKEN"])
    device = torch.device("cuda")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
    for alert in alerts:
        prompt = PROMPT_TEMPLATE.format(json.dumps(alert, indent=2))
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs, max_new_tokens=256, pad_token_id=tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens. The previous approach —
        # decoding the full sequence and str.replace(prompt, "") — is
        # unreliable: detokenized text does not always round-trip
        # byte-for-byte, which can leave the prompt embedded in the summary.
        prompt_length = inputs["input_ids"].shape[-1]
        alert["summary"] = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()
    return alerts


alerts = generate_summaries(alerts)

# Tabulate one row per state (state abbreviation + LLM summary) for the map.
records = [{"state": a["state"], "summary": a["summary"]} for a in alerts]
df = pd.DataFrame(records)


def get_map():
    def wrap_text(text, width=50):
        return "<br>".join(textwrap.wrap(text, width=width))

    df["wrapped_summary"] = df["summary"].apply(wrap_text)

    fig = go.Figure(
        go.Choropleth(
            locations=df["state"],
            z=[1 for _ in df["summary"]],
            locationmode="USA-states",
            colorscale=[
                [0, "lightgrey"],
                [1, "lightgrey"],
            ],  # Single color for all states
            showscale=False,
            text=df["wrapped_summary"],
            hoverinfo="text",
            hovertemplate="%{text}<extra></extra>",
        )
    )

    fig.update_layout(title_text="Streaming LLM Weather Alerts", geo_scope="usa")

    return fig


# Create Gradio interface: get_map takes no inputs and returns a Plotly
# figure rendered in a gr.Plot output component.
iface = gr.Interface(fn=get_map, inputs=None, outputs=gr.Plot())

# Launch the Gradio interface
iface.launch()