Spaces:
Sleeping
Sleeping
nssharmaofficial
committed on
Commit
•
2bdc1ae
1
Parent(s):
92562f2
Fix vocab filepath
Browse files
- source/config.py +0 -1
- source/vocab.py +2 -51
source/config.py
CHANGED
@@ -20,4 +20,3 @@ class Config(object):
|
|
20 |
self.ENCODER_WEIGHT_FILE = 'source/weights/encoder-32B-512H-1L-e5.pt'
|
21 |
self.DECODER_WEIGHT_FILE = 'source/weights/decoder-32B-512H-1L-e5.pt'
|
22 |
|
23 |
-
self.ROOT = os.path.join(os.path.expanduser('~'), 'Huggingface', 'ImageCaption')
|
|
|
20 |
self.ENCODER_WEIGHT_FILE = 'source/weights/encoder-32B-512H-1L-e5.pt'
|
21 |
self.DECODER_WEIGHT_FILE = 'source/weights/decoder-32B-512H-1L-e5.pt'
|
22 |
|
|
source/vocab.py
CHANGED
@@ -41,53 +41,6 @@ class Vocab:
|
|
41 |
"""
|
42 |
self.counter.update(self.splitter(sentence))
|
43 |
|
44 |
-
def build_vocab(self, vocab_size: int, file_name: str):
|
45 |
-
""" Build vocabulary dictionaries word2index and index2word from a text file at config.ROOT path
|
46 |
-
|
47 |
-
Args:
|
48 |
-
|
49 |
-
vocab_size (int): size of vocabulary (including 4 predefined tokens: <pad>, <sos>, <eos>, <unk>)
|
50 |
-
|
51 |
-
file_name (str): name of the text file from which the vocabulary will be built.
|
52 |
-
Note: the lines in file are assumed to be in form: 'word SPACE index' and
|
53 |
-
it assumes a header line (for example: 'captions.txt')
|
54 |
-
"""
|
55 |
-
|
56 |
-
filepath = os.path.join(self.config.ROOT, file_name)
|
57 |
-
|
58 |
-
try:
|
59 |
-
with open(filepath, 'r', encoding='utf-8') as file:
|
60 |
-
for i, line in enumerate(file):
|
61 |
-
# ignore header line
|
62 |
-
if i == 0:
|
63 |
-
continue
|
64 |
-
caption = line.strip().lower().split(",", 1)[1] # id=0, caption=1
|
65 |
-
self.add_sentence(caption)
|
66 |
-
except Exception as e:
|
67 |
-
print(f"Error processing file {filepath}: {e}")
|
68 |
-
return
|
69 |
-
|
70 |
-
# adding predefined tokens in the vocabulary
|
71 |
-
self._add_predefined_tokens()
|
72 |
-
|
73 |
-
words = self.counter.most_common(vocab_size - 4)
|
74 |
-
# (index + 4) because first 4 tokens are the predefined ones
|
75 |
-
for index, (word, _) in enumerate(words, start=4):
|
76 |
-
self.word2index[word] = index
|
77 |
-
self.index2word[index] = word
|
78 |
-
|
79 |
-
self.size = len(self.word2index)
|
80 |
-
|
81 |
-
# adding predefined tokens in the vocabulary
|
82 |
-
self.index2word[self.PADDING_INDEX] = '<pad>'
|
83 |
-
self.word2index['<pad>'] = self.PADDING_INDEX
|
84 |
-
self.index2word[self.SOS] = '<sos>'
|
85 |
-
self.word2index['<sos>'] = self.SOS
|
86 |
-
self.index2word[self.EOS] = '<eos>'
|
87 |
-
self.word2index['<eos>'] = self.EOS
|
88 |
-
self.index2word[self.UNKNOWN_WORD_INDEX] = '<unk>'
|
89 |
-
self.word2index['<unk>'] = self.UNKNOWN_WORD_INDEX
|
90 |
-
|
91 |
def word_to_index(self, word: str) -> int:
|
92 |
""" Map word to index from word2index dictionary in vocabulary
|
93 |
|
@@ -116,16 +69,14 @@ class Vocab:
|
|
116 |
except KeyError:
|
117 |
return self.index2word[self.UNKNOWN_WORD_INDEX]
|
118 |
|
119 |
-
def load_vocab(self,
|
120 |
-
""" Load the word2index and index2word dictionaries from a text file
|
121 |
|
122 |
Args:
|
123 |
file_name (str): name of the text file where the vocabulary is saved (i.e 'word2index.txt')
|
124 |
Note: the lines in file are assumed to be in form: 'word SPACE index' and it assumes a header line
|
125 |
"""
|
126 |
|
127 |
-
filepath = os.path.join(self.config.ROOT, file_name)
|
128 |
-
|
129 |
self.word2index = dict()
|
130 |
self.index2word = dict()
|
131 |
|
|
|
41 |
"""
|
42 |
self.counter.update(self.splitter(sentence))
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
def word_to_index(self, word: str) -> int:
|
45 |
""" Map word to index from word2index dictionary in vocabulary
|
46 |
|
|
|
69 |
except KeyError:
|
70 |
return self.index2word[self.UNKNOWN_WORD_INDEX]
|
71 |
|
72 |
+
def load_vocab(self, filepath: str):
|
73 |
+
""" Load the word2index and index2word dictionaries from a text file.
|
74 |
|
75 |
Args:
|
76 |
file_name (str): name of the text file where the vocabulary is saved (i.e 'word2index.txt')
|
77 |
Note: the lines in file are assumed to be in form: 'word SPACE index' and it assumes a header line
|
78 |
"""
|
79 |
|
|
|
|
|
80 |
self.word2index = dict()
|
81 |
self.index2word = dict()
|
82 |
|