# army_doctrine / app.py
# Author: ciaochris — last commit dea4868 ("Update app.py", verified); 5.51 kB
# Copyright 2024 Christopher Woodyard
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gradio as gr
import os
from groq import Groq
from functools import lru_cache
import time
import requests
import json
import random
import traceback
# Shared Groq API client used for every chat-completion request.
# The API key is taken from the GROQ_API_KEY environment variable.
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
# Load Army Doctrine snippets
try:
with open('army_doctrine_snippets.json', 'r') as f:
DOCTRINE_SNIPPETS = json.load(f)['snippets']
except FileNotFoundError:
print("Error: army_doctrine_snippets.json file not found.")
DOCTRINE_SNIPPETS = []
except json.JSONDecodeError:
print("Error: army_doctrine_snippets.json is not a valid JSON file.")
DOCTRINE_SNIPPETS = []
@lru_cache(maxsize=100)
def cached_generate_response(query, role, mission_type):
    """Ask the Groq model for doctrine guidance, memoizing successful replies.

    Args:
        query: The soldier's free-text question; ``None`` is treated as "".
        role: The user's role (e.g. "Squad Leader"); ``None`` is treated as "".
        mission_type: The mission category (e.g. "Defensive"); ``None`` is
            treated as "".

    Returns:
        The model's reply text.

    Raises:
        RuntimeError: If the model returns an empty message.
        Exception: Any error raised by the Groq client propagates unchanged.
            Raising here, instead of returning an error string, is deliberate.
            ``lru_cache`` only stores results that were returned, so a
            transient failure is not cached for these inputs, and the
            caller's retry loop gets a real chance to retry.
    """
    query = "" if query is None else str(query)
    role = "" if role is None else str(role)
    mission_type = "" if mission_type is None else str(mission_type)

    # A snippet is relevant if any of its keywords occurs in any input field.
    # Both sides are lower-cased so matching is case-insensitive however the
    # keywords are spelled in the JSON file.
    fields = (query.lower(), role.lower(), mission_type.lower())
    relevant_snippets = [
        snippet
        for snippet in DOCTRINE_SNIPPETS
        if any(
            str(keyword).lower() in field
            for keyword in snippet.get("keywords", [])
            for field in fields
        )
    ]
    if not relevant_snippets:
        # No keyword hit: fall back to up to three random snippets. While this
        # result stays cached, the random choice is fixed for these inputs.
        relevant_snippets = random.sample(DOCTRINE_SNIPPETS, min(3, len(DOCTRINE_SNIPPETS)))

    # Matches are not ranked; this keeps the first three in file order.
    context = "\n".join(snippet.get("text", "") for snippet in relevant_snippets[:3])

    prompt = f"""You are an AI assistant specialized in U.S. Army Doctrine, designed to assist soldiers in the field.
Provide a concise, accurate, and actionable response to the following query, based on official U.S. Army Doctrine.
Your response should be clear, direct, and suitable for field conditions.
User's Role: {role}
Mission Type: {mission_type}
Relevant Doctrine Context:
{context}
Soldier's Query: {query}
Response:"""

    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": "You are a U.S. Army Doctrine expert assistant. Provide clear, concise answers based on official doctrine.",
            },
            {
                "role": "user",
                "content": prompt,
            },
        ],
        model="llama-3.1-8b-instant",
        temperature=0.2,
        max_tokens=300,
    )
    content = chat_completion.choices[0].message.content
    if not content:
        # Raise rather than return so an empty reply is never cached.
        raise RuntimeError("The model returned an empty response.")
    return content
def generate_response(query, role, mission_type):
    """Gradio callback: validate inputs, call the model with retries, format the reply.

    Args:
        query: The soldier's free-text question.
        role: The selected role.
        mission_type: The selected mission type.

    Returns:
        A string that always starts with a "Debug Info" header echoing the raw
        inputs. The header is followed by one of:
          - a validation error, when any input is missing or blank;
          - the model's reply plus an educational-use disclaimer;
          - a failure message with the last error and its traceback, after
            ``max_retries`` attempts.
    """
    debug_info = f"Debug Info:\nQuery: {repr(query)}\nRole: {repr(role)}\nMission Type: {repr(mission_type)}\n\n"
    # Whitespace-only input counts as missing so it never triggers a model call.
    if any(not value or not str(value).strip() for value in (query, role, mission_type)):
        return debug_info + "Error: Please provide all required inputs (query, role, and mission type)."

    max_retries = 3
    retry_delay = 1  # seconds before the second attempt
    for attempt in range(max_retries):
        try:
            response = cached_generate_response(query, role, mission_type)
            return debug_info + response + "\n\nNote: This is AI-generated guidance based on simplified U.S. Army Doctrine concepts. It is for educational purposes only. Always refer to official sources and follow your chain of command for actual military operations."
        except Exception as e:
            if attempt < max_retries - 1:
                time.sleep(retry_delay)
                retry_delay *= 2  # Exponential backoff
            else:
                # NOTE(review): this shows the full traceback to end users;
                # consider logging it server-side only in production.
                return debug_info + f"Failed to generate response after {max_retries} attempts. Error: {str(e)}\n{traceback.format_exc()}"
# Gradio UI: a free-text query plus role and mission-type selectors. Their
# values are passed positionally to generate_response, whose string result
# is displayed as plain text.
_query_input = gr.Textbox(
    lines=2,
    placeholder="Enter your field situation or doctrine query here...",
    label="Query",
)
_role_input = gr.Dropdown(
    choices=["Squad Leader", "Platoon Leader", "Company Commander"],
    label="Your Role",
    value="Squad Leader",
)
_mission_input = gr.Radio(
    choices=["Offensive", "Defensive", "Stability", "Support"],
    label="Mission Type",
    value="Defensive",
)

iface = gr.Interface(
    fn=generate_response,
    inputs=[_query_input, _role_input, _mission_input],
    outputs="text",
    title="Vers3Dynamics U.S. Army Doctrine Field Assistant",
    description=(
        "Get quick, doctrine-based guidance for field situations. "
        "Always refer to official sources and follow your chain of command for actual military operations."
    ),
)

# Start the Gradio server for the interface defined above.
iface.launch()