import logging

import gradio as gr

from utils import checkpoints, load_model, log_perplexity

logger = logging.getLogger(__name__)


class ModelManager:
    """Class to manage model loading and perplexity calculation state.

    NOTE(review): ``create_interface`` builds a single instance and binds its
    methods as event handlers, so this state is shared by every session that
    uses the same interface.
    """

    def __init__(self):
        # Whatever ``load_model`` returned; ``None`` until a load succeeds.
        self.loaded_models = None

    def load_models(self, checkpoint_input_str: str) -> str:
        """Load models from a comma-separated string of checkpoint names.

        Args:
            checkpoint_input_str: Checkpoint names separated by commas.
                Surrounding whitespace and empty entries are ignored.

        Returns:
            A human-readable status message for the UI textbox.
        """
        checkpoint_list = [
            c.strip()
            for c in (checkpoint_input_str or "").split(",")
            if c.strip()
        ]
        if not checkpoint_list:
            return "Please enter at least one model checkpoint name."
        try:
            self.loaded_models = load_model(checkpoint_list)
        except Exception as e:
            # UI boundary: surface the message to the user, keep the
            # traceback in the logs. Previously loaded models are kept.
            logger.exception("Model loading failed for %s", checkpoint_list)
            return f"Model loading failed: {e}"
        return "Models loaded successfully!"

    def calculate_perplexity(self) -> dict | str:
        """Calculate perplexity using the loaded models.

        Returns:
            The value produced by ``log_perplexity()`` on success. On failure,
            a ``{"error": message}`` dict. The output component is ``gr.JSON``,
            which parses ``str`` return values as JSON text, so plain error
            strings would fail to render.
        """
        if self.loaded_models is None:
            return {"error": "Please load models first."}
        try:
            # NOTE(review): log_perplexity() takes no arguments; presumably it
            # uses models registered by load_model() inside utils -- confirm.
            return log_perplexity()
        except Exception as e:
            logger.exception("Perplexity calculation failed")
            return {"error": f"Perplexity calculation failed: {e}"}


def create_interface() -> gr.Blocks:
    """Create and return the Gradio interface."""
    manager = ModelManager()

    with gr.Blocks() as demo:
        gr.Markdown("# LLM PPL")

        checkpoint_input = gr.Textbox(
            label="Checkpoints",
            # Pre-fill with the default checkpoint names from utils.
            value=", ".join(checkpoints),
        )
        load_btn = gr.Button("Load Models", variant="primary")
        perplexity_btn = gr.Button("Compute PPL")

        load_output = gr.Textbox(label="Model Loading Status", interactive=False)
        perplexity_output = gr.JSON(label="PPL Results")

        # Connect event handlers
        load_btn.click(
            fn=manager.load_models, inputs=checkpoint_input, outputs=load_output
        )
        perplexity_btn.click(fn=manager.calculate_perplexity, outputs=perplexity_output)

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch()