import mlflow
import torch
import streamlit as st
from transformers import T5ForConditionalGeneration, T5Tokenizer

class InferenceBuilder:
    def __init__(self):
        # load the necessary configuration from model_config.yaml
        self.model_config = mlflow.models.ModelConfig(development_config="model_config.yaml")
        self.cybersolve_config = self.model_config.get("cybersolve_config")

    def load_tokenizer(self):
        tokenizer_name = self.cybersolve_config.get("tokenizer_name")

        # cache the tokenizer so that it doesn't re-download on every rerun;
        # @st.cache_resource cannot be applied directly to a method (a function with a self argument),
        # so we wrap the download in a nested function and cache that instead
        @st.cache_resource  # https://docs.streamlit.io/develop/concepts/architecture/caching
        def load_and_cache_tokenizer(tokenizer_name):
            tokenizer = T5Tokenizer.from_pretrained(tokenizer_name)  # CyberSolve uses the same tokenizer as the base FLAN-T5 model
            return tokenizer

        return load_and_cache_tokenizer(tokenizer_name)

    def load_model(self):
        model_name = self.cybersolve_config.get("model_name")

        # cache the model so that it doesn't re-download on every rerun;
        # @st.cache_resource cannot be applied directly to a method (a function with a self argument),
        # so we wrap the download in a nested function and cache that instead
        @st.cache_resource  # https://docs.streamlit.io/develop/concepts/architecture/caching
        def load_and_cache_model(model_name):
            # model = T5ForConditionalGeneration.from_pretrained(model_name).to("cuda")  # put the model on our Space's GPU
            model = T5ForConditionalGeneration.from_pretrained(model_name)  # move to GPU eventually
            return model

        return load_and_cache_model(model_name)
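

# A minimal usage sketch (commented out so importing this module stays side-effect
# free): it assumes a Streamlit entrypoint such as app.py in the same Space, and a
# hypothetical text input named `problem`. It shows how the cached loaders above
# might feed a seq2seq generate call; names and parameters are illustrative only.
#
# builder = InferenceBuilder()
# tokenizer = builder.load_tokenizer()
# model = builder.load_model()
#
# problem = st.text_input("Enter a math problem")  # hypothetical UI element
# if problem:
#     input_ids = tokenizer(problem, return_tensors="pt").input_ids
#     output_ids = model.generate(input_ids, max_new_tokens=64)
#     st.write(tokenizer.decode(output_ids[0], skip_special_tokens=True))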