streetyogi committed on
Commit
3ecf051
·
1 Parent(s): 4288d88

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +28 -23
main.py CHANGED
@@ -1,51 +1,56 @@
1
  from fastapi import FastAPI
2
  from fastapi.staticfiles import StaticFiles
3
  from fastapi.responses import FileResponse
4
- from transformers import T5Tokenizer, T5ForCausalLM, Trainer, TrainingArguments
5
 
6
  app = FastAPI()
7
 
8
  # Initialize the tokenizer and model
9
- tokenizer = T5Tokenizer.from_pretrained("t5-base")
10
- model = T5ForCausalLM.from_pretrained("t5-base")
11
 
 
12
  with open("cyberpunk_lore.txt", "r") as f:
13
- dataset = f.read()
 
 
14
 
15
- # Tokenize the dataset
16
- input_ids = tokenizer.batch_encode_plus(dataset, return_tensors="pt")["input_ids"]
17
-
18
- # Set up training arguments
19
  training_args = TrainingArguments(
20
- output_dir='./results',
21
- overwrite_output_dir=True,
22
- num_train_epochs=5,
23
- per_device_train_batch_size=1,
24
  save_steps=10_000,
25
  save_total_limit=2,
26
  )
27
 
28
- # Create a Trainer
29
  trainer = Trainer(
30
  model=model,
31
  args=training_args,
32
- train_dataset=input_ids,
33
- eval_dataset=input_ids
34
  )
35
 
36
- # Fine-tune the model
37
  trainer.train()
38
 
39
- # Create the inference pipeline
40
- pipe_flan = pipeline("text2text-generation", model=model)
41
 
42
- @app.get("/infer_t5")
43
- def t5(input):
44
- output = pipe_flan(input)
45
- return {"output": output[0]["generated_text"]}
46
 
47
- app.mount("/", StaticFiles(directory="static", html=True), name="static")
 
 
 
 
 
48
 
 
 
 
 
49
  @app.get("/")
50
  def index() -> FileResponse:
51
  return FileResponse(path="/app/static/index.html", media_type="text/html")
 
1
  from fastapi import FastAPI
2
  from fastapi.staticfiles import StaticFiles
3
  from fastapi.responses import FileResponse
4
+ from transformers import BertTokenizer, BertForMaskedLM, Trainer, TrainingArguments
5
 
6
  app = FastAPI()
7
 
8
  # Initialize the tokenizer and model
9
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# Prepare the training data: one example per non-blank line of the lore file.
# Blank lines are dropped because they would only yield empty examples.
with open("cyberpunk_lore.txt", "r", encoding="utf-8") as f:
    lore_lines = [line.strip() for line in f if line.strip()]

# Trainer consumes mapping-style examples ({"input_ids": [...], ...}), not bare
# tensors, so each line is tokenized into a BatchEncoding. Truncation keeps
# every example within the model's maximum sequence length.
train_data = [tokenizer(text, truncation=True) for text in lore_lines]

# Masked-LM fine-tuning needs `labels`; this collator pads each batch, masks
# random tokens and builds the labels from the original ids.
# Imported here so this section carries everything it needs.
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

# Define the training arguments
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=16,
    save_steps=10_000,
    save_total_limit=2,
)

# Create the trainer
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=train_data,
)

# Start the training
trainer.train()

# Save the fine-tuned model
trainer.save_model('./results')

# Use the fine-tuned model for inference. Trainer exposes it as the `model`
# attribute; it has no `get_model()` method.
model = trainer.model
 
 
42
 
43
+ # Create the inference endpoint
44
@app.post("/infer")
def infer(input: str):
    """Run the masked language model on ``input`` and return its prediction.

    The text is tokenized, passed through the model, and the highest-scoring
    vocabulary token at every position is decoded back into a string.

    Args:
        input: Text to run through the model. It may contain the tokenizer's
            mask token for positions the model should fill in.

    Returns:
        A dict ``{"output": <decoded text>}``. The raw logits tensor is not a
        JSON-serializable value, so it is converted to text before returning.
    """
    import torch  # local import: only this handler needs torch directly

    # Truncate so over-long requests cannot exceed the model's length limit.
    input_ids = tokenizer.encode(input, return_tensors="pt", truncation=True)
    # The model may live on a different device than freshly created tensors.
    input_ids = input_ids.to(model.device)

    # Disable dropout so repeated calls with the same text agree.
    model.eval()
    with torch.no_grad():
        logits = model(input_ids).logits

    # Highest-scoring token id per position for the single sequence in the batch.
    predicted_ids = logits.argmax(dim=-1)[0].tolist()
    output = tokenizer.decode(predicted_ids, skip_special_tokens=True)
    return {"output": output}
49
 
50
@app.get("/")
def index() -> FileResponse:
    """Serve the static front-end page for requests to the site root.

    Returns:
        A ``FileResponse`` for ``/app/static/index.html`` sent as ``text/html``.
    """
    # Registered exactly once: a second identical route for "/" would only
    # rebind the name `index` and add a handler that is never reached.
    return FileResponse(path="/app/static/index.html", media_type="text/html")