"""Continue training a word-level LSTM next-word model from a saved
safetensors checkpoint and its pickled vocabulary."""

import torch
import torch.nn as nn
import torch.optim as optim
import pickle
from torch.utils.data import Dataset, DataLoader
from safetensors.torch import load_file, save_file
import logging
import json

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Hyperparameters
sequence_length = 16
batch_size = 32
num_epochs = 1  # Continue training for 1 more epoch
learning_rate = 0.00001
embedding_dim = 256
hidden_dim = 512
num_layers = 2

# LSTM next-word model: embed token indices, run them through a stacked LSTM,
# and project the final hidden state onto the vocabulary.
class LSTMModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        # x: (batch, seq_len) token indices -> (batch, seq_len, embedding_dim)
        embeds = self.embedding(x)
        # lstm_out: (batch, seq_len, hidden_dim); the final hidden/cell states are unused
        lstm_out, _ = self.lstm(embeds)
        # Predict the next word from the last time step only: (batch, vocab_size)
        logits = self.fc(lstm_out[:, -1, :])
        return logits

# Load the model and vocabulary
logging.info('Loading the model and vocabulary...')
model_state_dict = load_file('lstm_model.safetensors')
with open('word2idx.pkl', 'rb') as f:
    word2idx = pickle.load(f)
with open('idx2word.pkl', 'rb') as f:
    idx2word = pickle.load(f)

vocab_size = len(word2idx)
model = LSTMModel(vocab_size, embedding_dim, hidden_dim, num_layers)
model.load_state_dict(model_state_dict)
model.train()

logging.info('Model and vocabulary loaded successfully.')

# Output the total number of parameters
total_params = sum(p.numel() for p in model.parameters())
logging.info(f'Total number of parameters: {total_params}')
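# For reference, a rough decomposition of that count under the settings above
# (with V = vocab_size; PyTorch LSTMs have 4 gates with weight_ih, weight_hh,
# and two bias vectors per layer):
#   embedding:    V * 256
#   LSTM layer 1: 4 * 512 * (256 + 512) + 2 * 4 * 512 = 1,576,960
#   LSTM layer 2: 4 * 512 * (512 + 512) + 2 * 4 * 512 = 2,101,248
#   output head:  512 * V + V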

# Read the training text. text.txt is expected to contain a JSON-encoded list
# of word tokens (the same tokenization used to build the vocabulary).
logging.info('Reading the text file...')
with open('text.txt', 'r') as file:
    text = file.read()
logging.info('Text file read successfully.')

# Build training pairs: each sliding window of `sequence_length` words is an
# input sequence, and the word that follows it is the label.
logging.info('Preprocessing the text...')
words = json.loads(text)
sequences = []
for i in range(len(words) - sequence_length):
    seq = words[i:i + sequence_length]
    label = words[i + sequence_length]
    sequences.append((seq, label))

logging.info(f'Number of sequences: {len(sequences)}')

# Dataset and DataLoader
class TextDataset(Dataset):
    def __init__(self, sequences, word2idx):
        self.sequences = sequences
        self.word2idx = word2idx

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq, label = self.sequences[idx]
        # Map words to indices, falling back to the <UNK> token for any
        # out-of-vocabulary word.
        seq_idx = [self.word2idx.get(word, self.word2idx['<UNK>']) for word in seq]
        label_idx = self.word2idx.get(label, self.word2idx['<UNK>'])
        return torch.tensor(seq_idx, dtype=torch.long), torch.tensor(label_idx, dtype=torch.long)

logging.info('Creating dataset and dataloader...')
dataset = TextDataset(sequences, word2idx)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
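# Optional sanity check (a small addition, not part of the original recipe):
# pull one batch and confirm the shapes the model expects, i.e. inputs of
# (batch_size, sequence_length) and targets of (batch_size,).
sample_inputs, sample_targets = next(iter(dataloader))
logging.info(f'Sample batch shapes: inputs {tuple(sample_inputs.shape)}, '
             f'targets {tuple(sample_targets.shape)}')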

# Continue training. CrossEntropyLoss expects raw logits, so the model
# deliberately applies no softmax in its forward pass.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

logging.info('Starting continued training...')
for epoch in range(num_epochs):
    for batch_idx, batch in enumerate(dataloader):
        inputs, targets = batch
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 10 == 0:
            logging.info(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(dataloader)}], Loss: {loss.item():.4f}')

# Save the updated model. safetensors stores only the tensors, so the
# vocabulary mappings are re-saved alongside it (unchanged, but kept in sync).
logging.info('Saving the updated model...')
save_file(model.state_dict(), 'lstm_model.safetensors')
with open('word2idx.pkl', 'wb') as f:
    pickle.dump(word2idx, f)
with open('idx2word.pkl', 'wb') as f:
    pickle.dump(idx2word, f)

logging.info('Updated model and vocabulary saved successfully.')
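
# Illustrative inference sketch (an assumption for demonstration, not part of
# the training recipe above): greedily predict the word that follows the first
# `sequence_length` tokens of the corpus. Assumes idx2word maps integer
# indices back to word strings, mirroring word2idx.
model.eval()
seed = words[:sequence_length]
seed_idx = torch.tensor(
    [[word2idx.get(w, word2idx['<UNK>']) for w in seed]], dtype=torch.long
)
with torch.no_grad():
    next_idx = model(seed_idx).argmax(dim=-1).item()
logging.info(f'Example prediction: {seed!r} -> {idx2word[next_idx]!r}')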