John Graham Reynolds committed
Commit b657fae · 1 Parent(s): d2eaab7

add code for building model and tokenizer

Files changed (1)
  1. model.py +34 -0
model.py ADDED
@@ -0,0 +1,34 @@
+ import mlflow
+ import torch
+ import streamlit as st
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
+
+ class InferenceBuilder:
+
+     def __init__(self):
+         # Load the necessary configuration from model_config.yaml
+         self.model_config = mlflow.models.ModelConfig(development_config="model_config.yaml")
+         self.cybersolve_config = self.model_config.get("cybersolve_config")
+
+     def load_tokenizer(self):
+         tokenizer_name = self.cybersolve_config.get("tokenizer_name")
+         # cache the tokenizer so it doesn't re-download on each Streamlit rerun;
+         # @st.cache_resource cannot directly decorate a method that takes self, so wrap an inner function
+         @st.cache_resource  # https://docs.streamlit.io/develop/concepts/architecture/caching
+         def load_and_cache_tokenizer(tokenizer_name):
+             tokenizer = T5Tokenizer.from_pretrained(tokenizer_name)  # CyberSolve uses the same tokenizer as the base FLAN-T5 model
+             return tokenizer
+
+         return load_and_cache_tokenizer(tokenizer_name)
+
+     def load_model(self):
+         model_name = self.cybersolve_config.get("model_name")
+         # cache the model so it doesn't re-download on each Streamlit rerun;
+         # @st.cache_resource cannot directly decorate a method that takes self, so wrap an inner function
+         @st.cache_resource  # https://docs.streamlit.io/develop/concepts/architecture/caching
+         def load_and_cache_model(model_name):
+             # model = T5ForConditionalGeneration.from_pretrained(model_name).to("cuda")  # put the model on our Space's GPU
+             model = T5ForConditionalGeneration.from_pretrained(model_name)  # move to GPU eventually
+             return model
+
+         return load_and_cache_model(model_name)
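For reference, model.py reads its settings through mlflow.models.ModelConfig, which implies the repo carries a model_config.yaml with a cybersolve_config block exposing tokenizer_name and model_name. A minimal sketch of that file, with illustrative values only (the actual checkpoint names are not part of this commit):

  # model_config.yaml -- sketch only; values are illustrative, not from the commit
  cybersolve_config:
    tokenizer_name: google/flan-t5-base          # assumed: CyberSolve reuses the base FLAN-T5 tokenizer
    model_name: your-org/cybersolve-checkpoint   # hypothetical model identifier

And a hedged sketch of how a Streamlit app might consume InferenceBuilder; the prompt and generation arguments below are assumptions for illustration, not part of this commit:

  # app.py -- hypothetical caller, not part of this commit
  from model import InferenceBuilder

  builder = InferenceBuilder()
  tokenizer = builder.load_tokenizer()  # cached by st.cache_resource across reruns
  model = builder.load_model()          # likewise cached; still on CPU, per the comment above

  prompt = "Solve for x: 4*x + 12 = 0"             # illustrative input
  inputs = tokenizer(prompt, return_tensors="pt")  # tokenize to PyTorch tensors
  outputs = model.generate(**inputs, max_new_tokens=64)
  print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Because both loaders delegate to an inner function decorated with st.cache_resource, repeated reruns of the app reuse the same tokenizer and model objects instead of re-downloading the weights.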