import torch
import torch.nn as nn
import torch.optim as optim
from torchtext import data
from gensim.corpora import WikiCorpus
from transformers import GPT2Tokenizer, GPT2Model
from functions import *

# NOTE(review): a commented-out GPT wrapper class (embedding -> GPT-2 ->
# LSTM -> linear head) was removed here as dead code; recover it from
# version control if it is ever needed again.

# Load the pretrained GPT-2 tokenizer and (headless) transformer model.
print('load gpt2 model')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')

# Load the raw French Wikipedia dump (gensim extracts articles lazily).
print('load custom data')
# wiki_corpus_en = WikiCorpus('data/enwiki-latest-pages-articles.xml.bz2')
wiki_corpus_fr = WikiCorpus('data/frwiki-latest-pages-articles.xml.bz2')
# stackoverflow_corpus = data.TabularDataset('data/stackoverflow.csv', format='csv', fields=['text'])

# Preprocess the data.
# BUG FIX: iterating a WikiCorpus directly yields gensim bag-of-words
# (token_id, count) pairs of varying length, which torch.tensor() cannot
# convert; get_texts() yields each article's raw token list instead.
print('Preprocess the data')
wiki_data_fr = [' '.join(tokens) for tokens in wiki_corpus_fr.get_texts()]
# NOTE(review): this materialises the entire dump in RAM -- for a
# full-size Wikipedia dump, consider streaming / itertools.islice.

# Convert the data to a format compatible with PyTorch: encode every
# article with the GPT-2 tokenizer, padded/truncated to a fixed length so
# the result is one rectangular LongTensor of shape (n_articles, 128).
print('Convert the data to a format compatible with PyTorch')
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 defines no pad token by default
encoded = tokenizer(wiki_data_fr, padding='max_length', truncation=True,
                    max_length=128, return_tensors='pt')
wiki_data_fr = encoded['input_ids']
# ---------------------------------------------------------------------------
# Optimiser / loss / training loop.
print('Define the Adam optimizer')
# NOTE(review): GPT2Model is a bare transformer with no output head, so a
# small 2-class linear head is added; its parameters are optimised jointly
# with the transformer's.
classifier = nn.Linear(model.config.hidden_size, 2)
optimizer = optim.Adam(
    list(model.parameters()) + list(classifier.parameters()), lr=0.001)

# Define the loss function (expects (N, C) logits and (N,) class indices).
print('Define the loss function')
criterion = nn.CrossEntropyLoss()

# Train the model
print('Train the model')
num_epochs = 10
# NOTE(review): labels are hard-coded for exactly 10 documents -- confirm
# the corpus loaded above actually yields 10 articles, otherwise the loss
# computation will fail on a batch-size mismatch.
labels = torch.tensor([0, 1, 1, 0, 0, 1, 0, 1, 0, 1])
for epoch in range(num_epochs):
    # BUG FIX: the original did `'epoch: ' + epoch`, which raises
    # TypeError (str + int); format the integer instead.
    print(f'epoch: {epoch}')
    # Reset gradients before the backward pass (conventional ordering;
    # the original zeroed them after step(), which also works).
    optimizer.zero_grad()
    # Forward pass.
    # BUG FIX: GPT2Model returns a ModelOutput object, not logits, so it
    # cannot be fed to CrossEntropyLoss directly; pool the final token's
    # hidden state through the classification head to get (N, 2) logits.
    hidden = model(wiki_data_fr).last_hidden_state   # (N, seq_len, hidden)
    outputs = classifier(hidden[:, -1, :])           # (N, 2) logits
    # Calculate the loss
    loss = criterion(outputs, labels)
    # Backward pass
    loss.backward()
    # Update the parameters
    optimizer.step()
    # Evaluate the model (helper from functions.py -- semantics not
    # verifiable from this file; presumably returns a float accuracy).
    accuracy = evaluate(model, wiki_data_fr)
    # Save the model weights and states (overwrites the previous epoch's
    # checkpoint; note the classifier head is not saved here).
    torch.save(model.state_dict(), 'model.pth')
    # Adjust the learning rate (helper from functions.py)
    adjust_learning_rate(optimizer, epoch)
    # Print the loss and accuracy
    print('Epoch: {}, Loss: {:.4f}, Accuracy: {:.4f}'.format(
        epoch + 1, loss.item(), accuracy))