Spaces:
Running
Running
torchmoji code
Browse files- .travis.yml +1 -1
- LICENSE +21 -0
- data/.gitkeep +1 -0
- data/Olympic/raw.pickle +3 -0
- data/PsychExp/raw.pickle +3 -0
- data/SCv1/raw.pickle +3 -0
- data/SCv2-GEN/raw.pickle +3 -0
- data/SE0714/raw.pickle +3 -0
- data/SS-Twitter/raw.pickle +3 -0
- data/SS-Youtube/raw.pickle +3 -0
- data/emoji_codes.json +67 -0
- data/kaggle-insults/raw.pickle +3 -0
- emoji_overview.png +0 -0
- environment.yml +41 -0
- examples/.gitkeep +1 -0
- examples/README.md +39 -0
- examples/__init__.py +0 -0
- examples/create_twitter_vocab.py +13 -0
- examples/dataset_split.py +59 -0
- examples/encode_texts.py +41 -0
- examples/example_helper.py +6 -0
- examples/finetune_insults_chain-thaw.py +44 -0
- examples/finetune_semeval_class-avg_f1.py +50 -0
- examples/finetune_youtube_last.py +35 -0
- examples/score_texts_emojis.py +85 -0
- examples/text_emojize.py +63 -0
- examples/tokenize_dataset.py +26 -0
- examples/vocab_extension.py +30 -0
- scripts/analyze_all_results.py +40 -0
- scripts/analyze_results.py +39 -0
- scripts/calculate_coverages.py +90 -0
- scripts/convert_all_datasets.py +110 -0
- scripts/download_weights.py +65 -0
- scripts/finetune_dataset.py +109 -0
- scripts/results/.gitkeep +1 -0
- setup.py +16 -0
- tests/test_finetuning.py +235 -0
- tests/test_helper.py +6 -0
- tests/test_sentence_tokenizer.py +113 -0
- tests/test_tokenizer.py +167 -0
- tests/test_word_generator.py +73 -0
.travis.yml
CHANGED
@@ -24,4 +24,4 @@ script:
|
|
24 |
- true # pytest --capture=sys # add other tests here
|
25 |
notifications:
|
26 |
on_success: change
|
27 |
-
on_failure: change # `always` will be the setting once code changes slow down
|
|
|
24 |
- true # pytest --capture=sys # add other tests here
|
25 |
notifications:
|
26 |
on_success: change
|
27 |
+
on_failure: change # `always` will be the setting once code changes slow down
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2017 Bjarke Felbo, Han Thi Nguyen, Thomas Wolf
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
data/.gitkeep
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
data/Olympic/raw.pickle
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:398d394ac1d7c2116166ca968bae9b1f9fd049f9e9281f05c94ae7b2ea97d427
|
3 |
+
size 227301
|
data/PsychExp/raw.pickle
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dc7d710f2ccd7e9d8e620be703a446ce7ec05818d5ce6afe43d1e6aa9ff4a8aa
|
3 |
+
size 3492229
|
data/SCv1/raw.pickle
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a65db490451dada57b88918a951d04082a51599d2cde24914f8c713312de89f5
|
3 |
+
size 868931
|
data/SCv2-GEN/raw.pickle
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:43ae3ea310130c2ca2089d60876ba6b08006d7f2e018a0519c4fdb7b166f992f
|
3 |
+
size 883467
|
data/SE0714/raw.pickle
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:66f0ecf48affe92bdacdeb64ab20c1c84b9990a3ac7b659a1a98aa29c9c4a064
|
3 |
+
size 126311
|
data/SS-Twitter/raw.pickle
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0ef34a4f0fe39b1bb45fcb72026bbf3b82ce2e2a14c13d39610b3b41f18fc98e
|
3 |
+
size 413660
|
data/SS-Youtube/raw.pickle
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:83ec15e393fb4f0dbb524946480de50e9baf9fef83a3e9eaf95caa3c425b87aa
|
3 |
+
size 396130
|
data/emoji_codes.json
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"0": ":joy:",
|
3 |
+
"1": ":unamused:",
|
4 |
+
"2": ":weary:",
|
5 |
+
"3": ":sob:",
|
6 |
+
"4": ":heart_eyes:",
|
7 |
+
"5": ":pensive:",
|
8 |
+
"6": ":ok_hand:",
|
9 |
+
"7": ":blush:",
|
10 |
+
"8": ":heart:",
|
11 |
+
"9": ":smirk:",
|
12 |
+
"10":":grin:",
|
13 |
+
"11":":notes:",
|
14 |
+
"12":":flushed:",
|
15 |
+
"13":":100:",
|
16 |
+
"14":":sleeping:",
|
17 |
+
"15":":relieved:",
|
18 |
+
"16":":relaxed:",
|
19 |
+
"17":":raised_hands:",
|
20 |
+
"18":":two_hearts:",
|
21 |
+
"19":":expressionless:",
|
22 |
+
"20":":sweat_smile:",
|
23 |
+
"21":":pray:",
|
24 |
+
"22":":confused:",
|
25 |
+
"23":":kissing_heart:",
|
26 |
+
"24":":hearts:",
|
27 |
+
"25":":neutral_face:",
|
28 |
+
"26":":information_desk_person:",
|
29 |
+
"27":":disappointed:",
|
30 |
+
"28":":see_no_evil:",
|
31 |
+
"29":":tired_face:",
|
32 |
+
"30":":v:",
|
33 |
+
"31":":sunglasses:",
|
34 |
+
"32":":rage:",
|
35 |
+
"33":":thumbsup:",
|
36 |
+
"34":":cry:",
|
37 |
+
"35":":sleepy:",
|
38 |
+
"36":":stuck_out_tongue_winking_eye:",
|
39 |
+
"37":":triumph:",
|
40 |
+
"38":":raised_hand:",
|
41 |
+
"39":":mask:",
|
42 |
+
"40":":clap:",
|
43 |
+
"41":":eyes:",
|
44 |
+
"42":":gun:",
|
45 |
+
"43":":persevere:",
|
46 |
+
"44":":imp:",
|
47 |
+
"45":":sweat:",
|
48 |
+
"46":":broken_heart:",
|
49 |
+
"47":":blue_heart:",
|
50 |
+
"48":":headphones:",
|
51 |
+
"49":":speak_no_evil:",
|
52 |
+
"50":":wink:",
|
53 |
+
"51":":skull:",
|
54 |
+
"52":":confounded:",
|
55 |
+
"53":":smile:",
|
56 |
+
"54":":stuck_out_tongue_winking_eye:",
|
57 |
+
"55":":angry:",
|
58 |
+
"56":":no_good:",
|
59 |
+
"57":":muscle:",
|
60 |
+
"58":":punch:",
|
61 |
+
"59":":purple_heart:",
|
62 |
+
"60":":sparkling_heart:",
|
63 |
+
"61":":blue_heart:",
|
64 |
+
"62":":grimacing:",
|
65 |
+
"63":":sparkles:"
|
66 |
+
}
|
67 |
+
|
data/kaggle-insults/raw.pickle
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2fbeca5470209163e04b6975fc5fb91889e79583fe6ff499f83966e36392fcda
|
3 |
+
size 1338159
|
emoji_overview.png
ADDED
environment.yml
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: torchMoji
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- defaults
|
5 |
+
dependencies:
|
6 |
+
- _libgcc_mutex=0.1
|
7 |
+
- blas=1.0
|
8 |
+
- ca-certificates=2019.11.27
|
9 |
+
- certifi=2019.11.28
|
10 |
+
- cffi=1.13.2
|
11 |
+
- cudatoolkit=10.1.243
|
12 |
+
- intel-openmp=2019.4
|
13 |
+
- libedit=3.1.20181209
|
14 |
+
- libffi=3.2.1
|
15 |
+
- libgcc-ng=9.1.0
|
16 |
+
- libgfortran-ng=7.3.0
|
17 |
+
- libstdcxx-ng=9.1.0
|
18 |
+
- mkl=2018.0.3
|
19 |
+
- ncurses=6.1
|
20 |
+
- ninja=1.9.0
|
21 |
+
- nose=1.3.7
|
22 |
+
- numpy=1.13.1
|
23 |
+
- openssl=1.1.1d
|
24 |
+
- pip=19.3.1
|
25 |
+
- pycparser=2.19
|
26 |
+
- python=3.6.9
|
27 |
+
- pytorch=1.3.1
|
28 |
+
- readline=7.0
|
29 |
+
- scikit-learn=0.19.0
|
30 |
+
- scipy=0.19.1
|
31 |
+
- setuptools=42.0.2
|
32 |
+
- sqlite=3.30.1
|
33 |
+
- text-unidecode=1.0
|
34 |
+
- tk=8.6.8
|
35 |
+
- wheel=0.33.6
|
36 |
+
- xz=5.2.4
|
37 |
+
- zlib=1.2.11
|
38 |
+
- pip:
|
39 |
+
- emoji==0.4.5
|
40 |
+
prefix: /home/cbowdon/miniconda3/envs/torchMoji
|
41 |
+
|
examples/.gitkeep
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
examples/README.md
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# torchMoji examples
|
2 |
+
|
3 |
+
## Initialization
|
4 |
+
[create_twitter_vocab.py](create_twitter_vocab.py)
|
5 |
+
Create a new vocabulary from a tsv file.
|
6 |
+
|
7 |
+
[tokenize_dataset.py](tokenize_dataset.py)
|
8 |
+
Tokenize a given dataset using the prebuilt vocabulary.
|
9 |
+
|
10 |
+
[vocab_extension.py](vocab_extension.py)
|
11 |
+
Extend the given vocabulary using dataset-specific words.
|
12 |
+
|
13 |
+
[dataset_split.py](dataset_split.py)
|
14 |
+
Split a given dataset into training, validation and testing.
|
15 |
+
|
16 |
+
## Use pretrained model/architecture
|
17 |
+
[score_texts_emojis.py](score_texts_emojis.py)
|
18 |
+
Use torchMoji to score texts for emoji distribution.
|
19 |
+
|
20 |
+
[text_emojize.py](text_emojize.py)
|
21 |
+
Use torchMoji to output emoji visualization from a single text input (mapped from `emoji_overview.png`)
|
22 |
+
|
23 |
+
```sh
|
24 |
+
python examples/text_emojize.py --text "I love mom's cooking\!"
|
25 |
+
# => I love mom's cooking! 😋 😍 💓 💛 ❤
|
26 |
+
```
|
27 |
+
|
28 |
+
[encode_texts.py](encode_texts.py)
|
29 |
+
Use torchMoji to encode the text into 2304-dimensional feature vectors for further modeling/analysis.
|
30 |
+
|
31 |
+
## Transfer learning
|
32 |
+
[finetune_youtube_last.py](finetune_youtube_last.py)
|
33 |
+
Finetune the model on the SS-Youtube dataset using the 'last' method.
|
34 |
+
|
35 |
+
[finetune_insults_chain-thaw.py](finetune_insults_chain-thaw.py)
|
36 |
+
Finetune the model on the Kaggle insults dataset (from blog post) using the 'chain-thaw' method.
|
37 |
+
|
38 |
+
[finetune_semeval_class-avg_f1.py](finetune_semeval_class-avg_f1.py)
|
39 |
+
Finetune the model on the SemeEval emotion dataset using the 'full' method and evaluate using the class average F1 metric.
|
examples/__init__.py
ADDED
File without changes
|
examples/create_twitter_vocab.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Creates a vocabulary from a tsv file.
|
2 |
+
"""
|
3 |
+
|
4 |
+
import codecs
|
5 |
+
import example_helper
|
6 |
+
from torchmoji.create_vocab import VocabBuilder
|
7 |
+
from torchmoji.word_generator import TweetWordGenerator
|
8 |
+
|
9 |
+
with codecs.open('../../twitterdata/tweets.2016-09-01', 'rU', 'utf-8') as stream:
|
10 |
+
wg = TweetWordGenerator(stream)
|
11 |
+
vb = VocabBuilder(wg)
|
12 |
+
vb.count_all_words()
|
13 |
+
vb.save_vocab()
|
examples/dataset_split.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Split a given dataset into three different datasets: training, validation and
|
3 |
+
testing.
|
4 |
+
|
5 |
+
This is achieved by splitting the given list of sentences into three separate
|
6 |
+
lists according to either a given ratio (e.g. [0.7, 0.1, 0.2]) or by an
|
7 |
+
explicit enumeration. The sentences are also tokenised using the given
|
8 |
+
vocabulary.
|
9 |
+
|
10 |
+
Also splits a given list of dictionaries containing information about
|
11 |
+
each sentence.
|
12 |
+
|
13 |
+
An additional parameter can be set 'extend_with', which will extend the given
|
14 |
+
vocabulary with up to 'extend_with' tokens, taken from the training dataset.
|
15 |
+
'''
|
16 |
+
from __future__ import print_function, unicode_literals
|
17 |
+
import example_helper
|
18 |
+
import json
|
19 |
+
|
20 |
+
from torchmoji.sentence_tokenizer import SentenceTokenizer
|
21 |
+
|
22 |
+
DATASET = [
|
23 |
+
'I am sentence 0',
|
24 |
+
'I am sentence 1',
|
25 |
+
'I am sentence 2',
|
26 |
+
'I am sentence 3',
|
27 |
+
'I am sentence 4',
|
28 |
+
'I am sentence 5',
|
29 |
+
'I am sentence 6',
|
30 |
+
'I am sentence 7',
|
31 |
+
'I am sentence 8',
|
32 |
+
'I am sentence 9 newword',
|
33 |
+
]
|
34 |
+
|
35 |
+
INFO_DICTS = [
|
36 |
+
{'label': 'sentence 0'},
|
37 |
+
{'label': 'sentence 1'},
|
38 |
+
{'label': 'sentence 2'},
|
39 |
+
{'label': 'sentence 3'},
|
40 |
+
{'label': 'sentence 4'},
|
41 |
+
{'label': 'sentence 5'},
|
42 |
+
{'label': 'sentence 6'},
|
43 |
+
{'label': 'sentence 7'},
|
44 |
+
{'label': 'sentence 8'},
|
45 |
+
{'label': 'sentence 9'},
|
46 |
+
]
|
47 |
+
|
48 |
+
with open('../model/vocabulary.json', 'r') as f:
|
49 |
+
vocab = json.load(f)
|
50 |
+
st = SentenceTokenizer(vocab, 30)
|
51 |
+
|
52 |
+
# Split using the default split ratio
|
53 |
+
print(st.split_train_val_test(DATASET, INFO_DICTS))
|
54 |
+
|
55 |
+
# Split explicitly
|
56 |
+
print(st.split_train_val_test(DATASET,
|
57 |
+
INFO_DICTS,
|
58 |
+
[[0, 1, 2, 4, 9], [5, 6], [7, 8, 3]],
|
59 |
+
extend_with=1))
|
examples/encode_texts.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
""" Use torchMoji to encode texts into emotional feature vectors.
|
4 |
+
"""
|
5 |
+
from __future__ import print_function, division, unicode_literals
|
6 |
+
import json
|
7 |
+
|
8 |
+
from torchmoji.sentence_tokenizer import SentenceTokenizer
|
9 |
+
from torchmoji.model_def import torchmoji_feature_encoding
|
10 |
+
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH
|
11 |
+
|
12 |
+
TEST_SENTENCES = ['I love mom\'s cooking',
|
13 |
+
'I love how you never reply back..',
|
14 |
+
'I love cruising with my homies',
|
15 |
+
'I love messing with yo mind!!',
|
16 |
+
'I love you and now you\'re just gone..',
|
17 |
+
'This is shit',
|
18 |
+
'This is the shit']
|
19 |
+
|
20 |
+
maxlen = 30
|
21 |
+
batch_size = 32
|
22 |
+
|
23 |
+
print('Tokenizing using dictionary from {}'.format(VOCAB_PATH))
|
24 |
+
with open(VOCAB_PATH, 'r') as f:
|
25 |
+
vocabulary = json.load(f)
|
26 |
+
st = SentenceTokenizer(vocabulary, maxlen)
|
27 |
+
tokenized, _, _ = st.tokenize_sentences(TEST_SENTENCES)
|
28 |
+
|
29 |
+
print('Loading model from {}.'.format(PRETRAINED_PATH))
|
30 |
+
model = torchmoji_feature_encoding(PRETRAINED_PATH)
|
31 |
+
print(model)
|
32 |
+
|
33 |
+
print('Encoding texts..')
|
34 |
+
encoding = model(tokenized)
|
35 |
+
|
36 |
+
print('First 5 dimensions for sentence: {}'.format(TEST_SENTENCES[0]))
|
37 |
+
print(encoding[0,:5])
|
38 |
+
|
39 |
+
# Now you could visualize the encodings to see differences,
|
40 |
+
# run a logistic regression classifier on top,
|
41 |
+
# or basically anything you'd like to do.
|
examples/example_helper.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Module import helper.
|
2 |
+
Modifies PATH in order to allow us to import the torchmoji directory.
|
3 |
+
"""
|
4 |
+
import sys
|
5 |
+
from os.path import abspath, dirname
|
6 |
+
sys.path.insert(0, dirname(dirname(abspath(__file__))))
|
examples/finetune_insults_chain-thaw.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Finetuning example.
|
2 |
+
|
3 |
+
Trains the torchMoji model on the kaggle insults dataset, using the 'chain-thaw'
|
4 |
+
finetuning method and the accuracy metric. See the blog post at
|
5 |
+
https://medium.com/@bjarkefelbo/what-can-we-learn-from-emojis-6beb165a5ea0
|
6 |
+
for more information. Note that results may differ a bit due to slight
|
7 |
+
changes in preprocessing and train/val/test split.
|
8 |
+
|
9 |
+
The 'chain-thaw' method does the following:
|
10 |
+
0) Load all weights except for the softmax layer. Extend the embedding layer if
|
11 |
+
necessary, initialising the new weights with random values.
|
12 |
+
1) Freeze every layer except the last (softmax) layer and train it.
|
13 |
+
2) Freeze every layer except the first layer and train it.
|
14 |
+
3) Freeze every layer except the second etc., until the second last layer.
|
15 |
+
4) Unfreeze all layers and train entire model.
|
16 |
+
"""
|
17 |
+
|
18 |
+
from __future__ import print_function
|
19 |
+
import example_helper
|
20 |
+
import json
|
21 |
+
from torchmoji.model_def import torchmoji_transfer
|
22 |
+
from torchmoji.global_variables import PRETRAINED_PATH
|
23 |
+
from torchmoji.finetuning import (
|
24 |
+
load_benchmark,
|
25 |
+
finetune)
|
26 |
+
|
27 |
+
|
28 |
+
DATASET_PATH = '../data/kaggle-insults/raw.pickle'
|
29 |
+
nb_classes = 2
|
30 |
+
|
31 |
+
with open('../model/vocabulary.json', 'r') as f:
|
32 |
+
vocab = json.load(f)
|
33 |
+
|
34 |
+
# Load dataset. Extend the existing vocabulary with up to 10000 tokens from
|
35 |
+
# the training dataset.
|
36 |
+
data = load_benchmark(DATASET_PATH, vocab, extend_with=10000)
|
37 |
+
|
38 |
+
# Set up model and finetune. Note that we have to extend the embedding layer
|
39 |
+
# with the number of tokens added to the vocabulary.
|
40 |
+
model = torchmoji_transfer(nb_classes, PRETRAINED_PATH, extend_embedding=data['added'])
|
41 |
+
print(model)
|
42 |
+
model, acc = finetune(model, data['texts'], data['labels'], nb_classes,
|
43 |
+
data['batch_size'], method='chain-thaw')
|
44 |
+
print('Acc: {}'.format(acc))
|
examples/finetune_semeval_class-avg_f1.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Finetuning example.
|
2 |
+
|
3 |
+
Trains the torchMoji model on the SemEval emotion dataset, using the 'last'
|
4 |
+
finetuning method and the class average F1 metric.
|
5 |
+
|
6 |
+
The 'last' method does the following:
|
7 |
+
0) Load all weights except for the softmax layer. Do not add tokens to the
|
8 |
+
vocabulary and do not extend the embedding layer.
|
9 |
+
1) Freeze all layers except for the softmax layer.
|
10 |
+
2) Train.
|
11 |
+
|
12 |
+
The class average F1 metric does the following:
|
13 |
+
1) For each class, relabel the dataset into binary classification
|
14 |
+
(belongs to/does not belong to this class).
|
15 |
+
2) Calculate F1 score for each class.
|
16 |
+
3) Compute the average of all F1 scores.
|
17 |
+
"""
|
18 |
+
|
19 |
+
from __future__ import print_function
|
20 |
+
import example_helper
|
21 |
+
import json
|
22 |
+
from torchmoji.finetuning import load_benchmark
|
23 |
+
from torchmoji.class_avg_finetuning import class_avg_finetune
|
24 |
+
from torchmoji.model_def import torchmoji_transfer
|
25 |
+
from torchmoji.global_variables import PRETRAINED_PATH
|
26 |
+
|
27 |
+
DATASET_PATH = '../data/SE0714/raw.pickle'
|
28 |
+
nb_classes = 3
|
29 |
+
|
30 |
+
with open('../model/vocabulary.json', 'r') as f:
|
31 |
+
vocab = json.load(f)
|
32 |
+
|
33 |
+
|
34 |
+
# Load dataset. Extend the existing vocabulary with up to 10000 tokens from
|
35 |
+
# the training dataset.
|
36 |
+
data = load_benchmark(DATASET_PATH, vocab, extend_with=10000)
|
37 |
+
|
38 |
+
# Set up model and finetune. Note that we have to extend the embedding layer
|
39 |
+
# with the number of tokens added to the vocabulary.
|
40 |
+
#
|
41 |
+
# Also note that when using class average F1 to evaluate, the model has to be
|
42 |
+
# defined with two classes, since the model will be trained for each class
|
43 |
+
# separately.
|
44 |
+
model = torchmoji_transfer(2, PRETRAINED_PATH, extend_embedding=data['added'])
|
45 |
+
print(model)
|
46 |
+
|
47 |
+
# For finetuning however, pass in the actual number of classes.
|
48 |
+
model, f1 = class_avg_finetune(model, data['texts'], data['labels'],
|
49 |
+
nb_classes, data['batch_size'], method='last')
|
50 |
+
print('F1: {}'.format(f1))
|
examples/finetune_youtube_last.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Finetuning example.
|
2 |
+
|
3 |
+
Trains the torchMoji model on the SS-Youtube dataset, using the 'last'
|
4 |
+
finetuning method and the accuracy metric.
|
5 |
+
|
6 |
+
The 'last' method does the following:
|
7 |
+
0) Load all weights except for the softmax layer. Do not add tokens to the
|
8 |
+
vocabulary and do not extend the embedding layer.
|
9 |
+
1) Freeze all layers except for the softmax layer.
|
10 |
+
2) Train.
|
11 |
+
"""
|
12 |
+
|
13 |
+
from __future__ import print_function
|
14 |
+
import example_helper
|
15 |
+
import json
|
16 |
+
from torchmoji.model_def import torchmoji_transfer
|
17 |
+
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH, ROOT_PATH
|
18 |
+
from torchmoji.finetuning import (
|
19 |
+
load_benchmark,
|
20 |
+
finetune)
|
21 |
+
|
22 |
+
DATASET_PATH = '{}/data/SS-Youtube/raw.pickle'.format(ROOT_PATH)
|
23 |
+
nb_classes = 2
|
24 |
+
|
25 |
+
with open(VOCAB_PATH, 'r') as f:
|
26 |
+
vocab = json.load(f)
|
27 |
+
|
28 |
+
# Load dataset.
|
29 |
+
data = load_benchmark(DATASET_PATH, vocab)
|
30 |
+
|
31 |
+
# Set up model and finetune
|
32 |
+
model = torchmoji_transfer(nb_classes, PRETRAINED_PATH)
|
33 |
+
print(model)
|
34 |
+
model, acc = finetune(model, data['texts'], data['labels'], nb_classes, data['batch_size'], method='last')
|
35 |
+
print('Acc: {}'.format(acc))
|
examples/score_texts_emojis.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
""" Use torchMoji to score texts for emoji distribution.
|
4 |
+
|
5 |
+
The resulting emoji ids (0-63) correspond to the mapping
|
6 |
+
in emoji_overview.png file at the root of the torchMoji repo.
|
7 |
+
|
8 |
+
Writes the result to a csv file.
|
9 |
+
"""
|
10 |
+
|
11 |
+
from __future__ import print_function, division, unicode_literals
|
12 |
+
|
13 |
+
import sys
|
14 |
+
from os.path import abspath, dirname
|
15 |
+
|
16 |
+
import json
|
17 |
+
import csv
|
18 |
+
import numpy as np
|
19 |
+
|
20 |
+
from torchmoji.sentence_tokenizer import SentenceTokenizer
|
21 |
+
from torchmoji.model_def import torchmoji_emojis
|
22 |
+
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH
|
23 |
+
|
24 |
+
sys.path.insert(0, dirname(dirname(abspath(__file__))))
|
25 |
+
|
26 |
+
OUTPUT_PATH = 'test_sentences.csv'
|
27 |
+
|
28 |
+
TEST_SENTENCES = ['I love mom\'s cooking',
|
29 |
+
'I love how you never reply back..',
|
30 |
+
'I love cruising with my homies',
|
31 |
+
'I love messing with yo mind!!',
|
32 |
+
'I love you and now you\'re just gone..',
|
33 |
+
'This is shit',
|
34 |
+
'This is the shit']
|
35 |
+
|
36 |
+
|
37 |
+
def top_elements(array, k):
|
38 |
+
ind = np.argpartition(array, -k)[-k:]
|
39 |
+
return ind[np.argsort(array[ind])][::-1]
|
40 |
+
|
41 |
+
maxlen = 30
|
42 |
+
|
43 |
+
print('Tokenizing using dictionary from {}'.format(VOCAB_PATH))
|
44 |
+
with open(VOCAB_PATH, 'r') as f:
|
45 |
+
vocabulary = json.load(f)
|
46 |
+
|
47 |
+
st = SentenceTokenizer(vocabulary, maxlen)
|
48 |
+
|
49 |
+
print('Loading model from {}.'.format(PRETRAINED_PATH))
|
50 |
+
model = torchmoji_emojis(PRETRAINED_PATH)
|
51 |
+
print(model)
|
52 |
+
|
53 |
+
def doImportableFunction():
|
54 |
+
print('Running predictions.')
|
55 |
+
tokenized, _, _ = st.tokenize_sentences(TEST_SENTENCES)
|
56 |
+
prob = model(tokenized)
|
57 |
+
|
58 |
+
for prob in [prob]:
|
59 |
+
# Find top emojis for each sentence. Emoji ids (0-63)
|
60 |
+
# correspond to the mapping in emoji_overview.png
|
61 |
+
# at the root of the torchMoji repo.
|
62 |
+
print('Writing results to {}'.format(OUTPUT_PATH))
|
63 |
+
scores = []
|
64 |
+
for i, t in enumerate(TEST_SENTENCES):
|
65 |
+
t_tokens = tokenized[i]
|
66 |
+
t_score = [t]
|
67 |
+
t_prob = prob[i]
|
68 |
+
ind_top = top_elements(t_prob, 5)
|
69 |
+
t_score.append(sum(t_prob[ind_top]))
|
70 |
+
t_score.extend(ind_top)
|
71 |
+
t_score.extend([t_prob[ind] for ind in ind_top])
|
72 |
+
scores.append(t_score)
|
73 |
+
print(t_score)
|
74 |
+
|
75 |
+
with open(OUTPUT_PATH, 'w') as csvfile:
|
76 |
+
writer = csv.writer(csvfile, delimiter=str(','), lineterminator='\n')
|
77 |
+
writer.writerow(['Text', 'Top5%',
|
78 |
+
'Emoji_1', 'Emoji_2', 'Emoji_3', 'Emoji_4', 'Emoji_5',
|
79 |
+
'Pct_1', 'Pct_2', 'Pct_3', 'Pct_4', 'Pct_5'])
|
80 |
+
for i, row in enumerate(scores):
|
81 |
+
try:
|
82 |
+
writer.writerow(row)
|
83 |
+
except:
|
84 |
+
print("Exception at row {}!".format(i))
|
85 |
+
return
|
examples/text_emojize.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
""" Use torchMoji to predict emojis from a single text input
|
4 |
+
"""
|
5 |
+
|
6 |
+
from __future__ import print_function, division, unicode_literals
|
7 |
+
import example_helper
|
8 |
+
import json
|
9 |
+
import csv
|
10 |
+
import argparse
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import emoji
|
14 |
+
|
15 |
+
from torchmoji.sentence_tokenizer import SentenceTokenizer
|
16 |
+
from torchmoji.model_def import torchmoji_emojis
|
17 |
+
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH
|
18 |
+
|
19 |
+
# Emoji map in emoji_overview.png
|
20 |
+
EMOJIS = ":joy: :unamused: :weary: :sob: :heart_eyes: \
|
21 |
+
:pensive: :ok_hand: :blush: :heart: :smirk: \
|
22 |
+
:grin: :notes: :flushed: :100: :sleeping: \
|
23 |
+
:relieved: :relaxed: :raised_hands: :two_hearts: :expressionless: \
|
24 |
+
:sweat_smile: :pray: :confused: :kissing_heart: :heartbeat: \
|
25 |
+
:neutral_face: :information_desk_person: :disappointed: :see_no_evil: :tired_face: \
|
26 |
+
:v: :sunglasses: :rage: :thumbsup: :cry: \
|
27 |
+
:sleepy: :yum: :triumph: :hand: :mask: \
|
28 |
+
:clap: :eyes: :gun: :persevere: :smiling_imp: \
|
29 |
+
:sweat: :broken_heart: :yellow_heart: :musical_note: :speak_no_evil: \
|
30 |
+
:wink: :skull: :confounded: :smile: :stuck_out_tongue_winking_eye: \
|
31 |
+
:angry: :no_good: :muscle: :facepunch: :purple_heart: \
|
32 |
+
:sparkling_heart: :blue_heart: :grimacing: :sparkles:".split(' ')
|
33 |
+
|
34 |
+
def top_elements(array, k):
|
35 |
+
ind = np.argpartition(array, -k)[-k:]
|
36 |
+
return ind[np.argsort(array[ind])][::-1]
|
37 |
+
|
38 |
+
if __name__ == "__main__":
|
39 |
+
argparser = argparse.ArgumentParser()
|
40 |
+
argparser.add_argument('--text', type=str, required=True, help="Input text to emojize")
|
41 |
+
argparser.add_argument('--maxlen', type=int, default=30, help="Max length of input text")
|
42 |
+
args = argparser.parse_args()
|
43 |
+
|
44 |
+
# Tokenizing using dictionary
|
45 |
+
with open(VOCAB_PATH, 'r') as f:
|
46 |
+
vocabulary = json.load(f)
|
47 |
+
|
48 |
+
st = SentenceTokenizer(vocabulary, args.maxlen)
|
49 |
+
|
50 |
+
# Loading model
|
51 |
+
model = torchmoji_emojis(PRETRAINED_PATH)
|
52 |
+
# Running predictions
|
53 |
+
tokenized, _, _ = st.tokenize_sentences([args.text])
|
54 |
+
# Get sentence probability
|
55 |
+
prob = model(tokenized)[0]
|
56 |
+
|
57 |
+
# Top emoji id
|
58 |
+
emoji_ids = top_elements(prob, 5)
|
59 |
+
|
60 |
+
# map to emojis
|
61 |
+
emojis = map(lambda x: EMOJIS[x], emoji_ids)
|
62 |
+
|
63 |
+
print(emoji.emojize("{} {}".format(args.text,' '.join(emojis)), use_aliases=True))
|
examples/tokenize_dataset.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Take a given list of sentences and turn it into a numpy array, where each
|
3 |
+
number corresponds to a word. Padding is used (number 0) to ensure fixed length
|
4 |
+
of sentences.
|
5 |
+
"""
|
6 |
+
|
7 |
+
from __future__ import print_function, unicode_literals
|
8 |
+
import example_helper
|
9 |
+
import json
|
10 |
+
from torchmoji.sentence_tokenizer import SentenceTokenizer
|
11 |
+
|
12 |
+
with open('../model/vocabulary.json', 'r') as f:
|
13 |
+
vocabulary = json.load(f)
|
14 |
+
|
15 |
+
st = SentenceTokenizer(vocabulary, 30)
|
16 |
+
test_sentences = [
|
17 |
+
'\u2014 -- \u203c !!\U0001F602',
|
18 |
+
'Hello world!',
|
19 |
+
'This is a sample tweet #example',
|
20 |
+
]
|
21 |
+
|
22 |
+
tokens, infos, stats = st.tokenize_sentences(test_sentences)
|
23 |
+
|
24 |
+
print(tokens)
|
25 |
+
print(infos)
|
26 |
+
print(stats)
|
examples/vocab_extension.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Extend the given vocabulary using dataset-specific words.
|
3 |
+
|
4 |
+
1. First create a vocabulary for the specific dataset.
|
5 |
+
2. Find all words not in our vocabulary, but in the dataset vocabulary.
|
6 |
+
3. Take top X (default=1000) of these words and add them to the vocabulary.
|
7 |
+
4. Save this combined vocabulary and embedding matrix, which can now be used.
|
8 |
+
"""
|
9 |
+
|
10 |
+
from __future__ import print_function, unicode_literals
|
11 |
+
import example_helper
|
12 |
+
import json
|
13 |
+
from torchmoji.create_vocab import extend_vocab, VocabBuilder
|
14 |
+
from torchmoji.word_generator import WordGenerator
|
15 |
+
|
16 |
+
new_words = ['#zzzzaaazzz', 'newword', 'newword']
|
17 |
+
word_gen = WordGenerator(new_words)
|
18 |
+
vb = VocabBuilder(word_gen)
|
19 |
+
vb.count_all_words()
|
20 |
+
|
21 |
+
with open('../model/vocabulary.json') as f:
|
22 |
+
vocab = json.load(f)
|
23 |
+
|
24 |
+
print(len(vocab))
|
25 |
+
print(vb.word_counts)
|
26 |
+
extend_vocab(vocab, vb, max_tokens=1)
|
27 |
+
|
28 |
+
# 'newword' should be added because it's more frequent in the given vocab
|
29 |
+
print(vocab['newword'])
|
30 |
+
print(len(vocab))
|
scripts/analyze_all_results.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
|
3 |
+
# allow us to import the codebase directory
|
4 |
+
import sys
|
5 |
+
import glob
|
6 |
+
import numpy as np
|
7 |
+
from os.path import dirname, abspath
|
8 |
+
sys.path.insert(0, dirname(dirname(abspath(__file__))))
|
9 |
+
|
10 |
+
DATASETS = ['SE0714', 'Olympic', 'PsychExp', 'SS-Twitter', 'SS-Youtube',
|
11 |
+
'SCv1', 'SV2-GEN'] # 'SE1604' excluded due to Twitter's ToS
|
12 |
+
|
13 |
+
def get_results(dset):
|
14 |
+
METHOD = 'last'
|
15 |
+
RESULTS_DIR = 'results/'
|
16 |
+
RESULT_PATHS = glob.glob('{}/{}_{}_*_results.txt'.format(RESULTS_DIR, dset, METHOD))
|
17 |
+
assert len(RESULT_PATHS)
|
18 |
+
|
19 |
+
scores = []
|
20 |
+
for path in RESULT_PATHS:
|
21 |
+
with open(path) as f:
|
22 |
+
score = f.readline().split(':')[1]
|
23 |
+
scores.append(float(score))
|
24 |
+
|
25 |
+
average = np.mean(scores)
|
26 |
+
maximum = max(scores)
|
27 |
+
minimum = min(scores)
|
28 |
+
std = np.std(scores)
|
29 |
+
|
30 |
+
print('Dataset: {}'.format(dset))
|
31 |
+
print('Method: {}'.format(METHOD))
|
32 |
+
print('Number of results: {}'.format(len(scores)))
|
33 |
+
print('--------------------------')
|
34 |
+
print('Average: {}'.format(average))
|
35 |
+
print('Maximum: {}'.format(maximum))
|
36 |
+
print('Minimum: {}'.format(minimum))
|
37 |
+
print('Standard deviaton: {}'.format(std))
|
38 |
+
|
39 |
+
for dset in DATASETS:
|
40 |
+
get_results(dset)
|
scripts/analyze_results.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
|
3 |
+
import sys
|
4 |
+
import glob
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
DATASET = 'SS-Twitter' # 'SE1604' excluded due to Twitter's ToS
|
8 |
+
METHOD = 'new'
|
9 |
+
|
10 |
+
# Optional usage: analyze_results.py <dataset> <method>
|
11 |
+
if len(sys.argv) == 3:
|
12 |
+
DATASET = sys.argv[1]
|
13 |
+
METHOD = sys.argv[2]
|
14 |
+
|
15 |
+
RESULTS_DIR = 'results/'
|
16 |
+
RESULT_PATHS = glob.glob('{}/{}_{}_*_results.txt'.format(RESULTS_DIR, DATASET, METHOD))
|
17 |
+
|
18 |
+
if not RESULT_PATHS:
|
19 |
+
print('Could not find results for \'{}\' using \'{}\' in directory \'{}\'.'.format(DATASET, METHOD, RESULTS_DIR))
|
20 |
+
else:
|
21 |
+
scores = []
|
22 |
+
for path in RESULT_PATHS:
|
23 |
+
with open(path) as f:
|
24 |
+
score = f.readline().split(':')[1]
|
25 |
+
scores.append(float(score))
|
26 |
+
|
27 |
+
average = np.mean(scores)
|
28 |
+
maximum = max(scores)
|
29 |
+
minimum = min(scores)
|
30 |
+
std = np.std(scores)
|
31 |
+
|
32 |
+
print('Dataset: {}'.format(DATASET))
|
33 |
+
print('Method: {}'.format(METHOD))
|
34 |
+
print('Number of results: {}'.format(len(scores)))
|
35 |
+
print('--------------------------')
|
36 |
+
print('Average: {}'.format(average))
|
37 |
+
print('Maximum: {}'.format(maximum))
|
38 |
+
print('Minimum: {}'.format(minimum))
|
39 |
+
print('Standard deviaton: {}'.format(std))
|
scripts/calculate_coverages.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import pickle
|
3 |
+
import json
|
4 |
+
import csv
|
5 |
+
import sys
|
6 |
+
from io import open
|
7 |
+
|
8 |
+
# Allow us to import the torchmoji directory
|
9 |
+
from os.path import dirname, abspath
|
10 |
+
sys.path.insert(0, dirname(dirname(abspath(__file__))))
|
11 |
+
|
12 |
+
from torchmoji.sentence_tokenizer import SentenceTokenizer, coverage
|
13 |
+
|
14 |
+
try:
|
15 |
+
unicode # Python 2
|
16 |
+
except NameError:
|
17 |
+
unicode = str # Python 3
|
18 |
+
|
19 |
+
IS_PYTHON2 = int(sys.version[0]) == 2
|
20 |
+
|
21 |
+
OUTPUT_PATH = 'coverage.csv'
|
22 |
+
DATASET_PATHS = [
|
23 |
+
'../data/Olympic/raw.pickle',
|
24 |
+
'../data/PsychExp/raw.pickle',
|
25 |
+
'../data/SCv1/raw.pickle',
|
26 |
+
'../data/SCv2-GEN/raw.pickle',
|
27 |
+
'../data/SE0714/raw.pickle',
|
28 |
+
#'../data/SE1604/raw.pickle', # Excluded due to Twitter's ToS
|
29 |
+
'../data/SS-Twitter/raw.pickle',
|
30 |
+
'../data/SS-Youtube/raw.pickle',
|
31 |
+
]
|
32 |
+
|
33 |
+
with open('../model/vocabulary.json', 'r') as f:
|
34 |
+
vocab = json.load(f)
|
35 |
+
|
36 |
+
results = []
|
37 |
+
for p in DATASET_PATHS:
|
38 |
+
coverage_result = [p]
|
39 |
+
print('Calculating coverage for {}'.format(p))
|
40 |
+
with open(p, 'rb') as f:
|
41 |
+
if IS_PYTHON2:
|
42 |
+
s = pickle.load(f)
|
43 |
+
else:
|
44 |
+
s = pickle.load(f, fix_imports=True)
|
45 |
+
|
46 |
+
# Decode data
|
47 |
+
try:
|
48 |
+
s['texts'] = [unicode(x) for x in s['texts']]
|
49 |
+
except UnicodeDecodeError:
|
50 |
+
s['texts'] = [x.decode('utf-8') for x in s['texts']]
|
51 |
+
|
52 |
+
# Own
|
53 |
+
st = SentenceTokenizer({}, 30)
|
54 |
+
tests, dicts, _ = st.split_train_val_test(s['texts'], s['info'],
|
55 |
+
[s['train_ind'],
|
56 |
+
s['val_ind'],
|
57 |
+
s['test_ind']],
|
58 |
+
extend_with=10000)
|
59 |
+
coverage_result.append(coverage(tests[2]))
|
60 |
+
|
61 |
+
# Last
|
62 |
+
st = SentenceTokenizer(vocab, 30)
|
63 |
+
tests, dicts, _ = st.split_train_val_test(s['texts'], s['info'],
|
64 |
+
[s['train_ind'],
|
65 |
+
s['val_ind'],
|
66 |
+
s['test_ind']],
|
67 |
+
extend_with=0)
|
68 |
+
coverage_result.append(coverage(tests[2]))
|
69 |
+
|
70 |
+
# Full
|
71 |
+
st = SentenceTokenizer(vocab, 30)
|
72 |
+
tests, dicts, _ = st.split_train_val_test(s['texts'], s['info'],
|
73 |
+
[s['train_ind'],
|
74 |
+
s['val_ind'],
|
75 |
+
s['test_ind']],
|
76 |
+
extend_with=10000)
|
77 |
+
coverage_result.append(coverage(tests[2]))
|
78 |
+
|
79 |
+
results.append(coverage_result)
|
80 |
+
|
81 |
+
with open(OUTPUT_PATH, 'wb') as csvfile:
|
82 |
+
writer = csv.writer(csvfile, delimiter='\t', lineterminator='\n')
|
83 |
+
writer.writerow(['Dataset', 'Own', 'Last', 'Full'])
|
84 |
+
for i, row in enumerate(results):
|
85 |
+
try:
|
86 |
+
writer.writerow(row)
|
87 |
+
except:
|
88 |
+
print("Exception at row {}!".format(i))
|
89 |
+
|
90 |
+
print('Saved to {}'.format(OUTPUT_PATH))
|
scripts/convert_all_datasets.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
|
3 |
+
import json
|
4 |
+
import math
|
5 |
+
import pickle
|
6 |
+
import sys
|
7 |
+
from io import open
|
8 |
+
import numpy as np
|
9 |
+
from os.path import abspath, dirname
|
10 |
+
sys.path.insert(0, dirname(dirname(abspath(__file__))))
|
11 |
+
|
12 |
+
from torchmoji.word_generator import WordGenerator
|
13 |
+
from torchmoji.create_vocab import VocabBuilder
|
14 |
+
from torchmoji.sentence_tokenizer import SentenceTokenizer, extend_vocab, coverage
|
15 |
+
from torchmoji.tokenizer import tokenize
|
16 |
+
|
17 |
+
try:
|
18 |
+
unicode # Python 2
|
19 |
+
except NameError:
|
20 |
+
unicode = str # Python 3
|
21 |
+
|
22 |
+
IS_PYTHON2 = int(sys.version[0]) == 2
|
23 |
+
|
24 |
+
DATASETS = [
|
25 |
+
'Olympic',
|
26 |
+
'PsychExp',
|
27 |
+
'SCv1',
|
28 |
+
'SCv2-GEN',
|
29 |
+
'SE0714',
|
30 |
+
#'SE1604', # Excluded due to Twitter's ToS
|
31 |
+
'SS-Twitter',
|
32 |
+
'SS-Youtube',
|
33 |
+
]
|
34 |
+
|
35 |
+
DIR = '../data'
|
36 |
+
FILENAME_RAW = 'raw.pickle'
|
37 |
+
FILENAME_OWN = 'own_vocab.pickle'
|
38 |
+
FILENAME_OUR = 'twitter_vocab.pickle'
|
39 |
+
FILENAME_COMBINED = 'combined_vocab.pickle'
|
40 |
+
|
41 |
+
|
42 |
+
def roundup(x):
|
43 |
+
return int(math.ceil(x / 10.0)) * 10
|
44 |
+
|
45 |
+
|
46 |
+
def format_pickle(dset, train_texts, val_texts, test_texts, train_labels, val_labels, test_labels):
|
47 |
+
return {'dataset': dset,
|
48 |
+
'train_texts': train_texts,
|
49 |
+
'val_texts': val_texts,
|
50 |
+
'test_texts': test_texts,
|
51 |
+
'train_labels': train_labels,
|
52 |
+
'val_labels': val_labels,
|
53 |
+
'test_labels': test_labels}
|
54 |
+
|
55 |
+
def convert_dataset(filepath, extend_with, vocab):
|
56 |
+
print('-- Generating {} '.format(filepath))
|
57 |
+
sys.stdout.flush()
|
58 |
+
st = SentenceTokenizer(vocab, maxlen)
|
59 |
+
tokenized, dicts, _ = st.split_train_val_test(texts,
|
60 |
+
labels,
|
61 |
+
[data['train_ind'],
|
62 |
+
data['val_ind'],
|
63 |
+
data['test_ind']],
|
64 |
+
extend_with=extend_with)
|
65 |
+
pick = format_pickle(dset, tokenized[0], tokenized[1], tokenized[2],
|
66 |
+
dicts[0], dicts[1], dicts[2])
|
67 |
+
with open(filepath, 'w') as f:
|
68 |
+
pickle.dump(pick, f)
|
69 |
+
cover = coverage(tokenized[2])
|
70 |
+
|
71 |
+
print(' done. Coverage: {}'.format(cover))
|
72 |
+
|
73 |
+
with open('../model/vocabulary.json', 'r') as f:
|
74 |
+
vocab = json.load(f)
|
75 |
+
|
76 |
+
for dset in DATASETS:
|
77 |
+
print('Converting {}'.format(dset))
|
78 |
+
|
79 |
+
PATH_RAW = '{}/{}/{}'.format(DIR, dset, FILENAME_RAW)
|
80 |
+
PATH_OWN = '{}/{}/{}'.format(DIR, dset, FILENAME_OWN)
|
81 |
+
PATH_OUR = '{}/{}/{}'.format(DIR, dset, FILENAME_OUR)
|
82 |
+
PATH_COMBINED = '{}/{}/{}'.format(DIR, dset, FILENAME_COMBINED)
|
83 |
+
|
84 |
+
with open(PATH_RAW, 'rb') as dataset:
|
85 |
+
if IS_PYTHON2:
|
86 |
+
data = pickle.load(dataset)
|
87 |
+
else:
|
88 |
+
data = pickle.load(dataset, fix_imports=True)
|
89 |
+
|
90 |
+
# Decode data
|
91 |
+
try:
|
92 |
+
texts = [unicode(x) for x in data['texts']]
|
93 |
+
except UnicodeDecodeError:
|
94 |
+
texts = [x.decode('utf-8') for x in data['texts']]
|
95 |
+
|
96 |
+
wg = WordGenerator(texts)
|
97 |
+
vb = VocabBuilder(wg)
|
98 |
+
vb.count_all_words()
|
99 |
+
|
100 |
+
# Calculate max length of sequences considered
|
101 |
+
# Adjust batch_size accordingly to prevent GPU overflow
|
102 |
+
lengths = [len(tokenize(t)) for t in texts]
|
103 |
+
maxlen = roundup(np.percentile(lengths, 80.0))
|
104 |
+
|
105 |
+
# Extract labels
|
106 |
+
labels = [x['label'] for x in data['info']]
|
107 |
+
|
108 |
+
convert_dataset(PATH_OWN, 50000, {})
|
109 |
+
convert_dataset(PATH_OUR, 0, vocab)
|
110 |
+
convert_dataset(PATH_COMBINED, 10000, vocab)
|
scripts/download_weights.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import os
|
3 |
+
from subprocess import call
|
4 |
+
from builtins import input
|
5 |
+
|
6 |
+
curr_folder = os.path.basename(os.path.normpath(os.getcwd()))
|
7 |
+
|
8 |
+
weights_filename = 'pytorch_model.bin'
|
9 |
+
weights_folder = 'model'
|
10 |
+
weights_path = '{}/{}'.format(weights_folder, weights_filename)
|
11 |
+
if curr_folder == 'scripts':
|
12 |
+
weights_path = '../' + weights_path
|
13 |
+
weights_download_link = 'https://www.dropbox.com/s/q8lax9ary32c7t9/pytorch_model.bin?dl=0#'
|
14 |
+
|
15 |
+
|
16 |
+
MB_FACTOR = float(1<<20)
|
17 |
+
|
18 |
+
def prompt():
|
19 |
+
while True:
|
20 |
+
valid = {
|
21 |
+
'y': True,
|
22 |
+
'ye': True,
|
23 |
+
'yes': True,
|
24 |
+
'n': False,
|
25 |
+
'no': False,
|
26 |
+
}
|
27 |
+
choice = input().lower()
|
28 |
+
if choice in valid:
|
29 |
+
return valid[choice]
|
30 |
+
else:
|
31 |
+
print('Please respond with \'y\' or \'n\' (or \'yes\' or \'no\')')
|
32 |
+
|
33 |
+
download = True
|
34 |
+
if os.path.exists(weights_path):
|
35 |
+
print('Weight file already exists at {}. Would you like to redownload it anyway? [y/n]'.format(weights_path))
|
36 |
+
download = prompt()
|
37 |
+
already_exists = True
|
38 |
+
else:
|
39 |
+
already_exists = False
|
40 |
+
|
41 |
+
if download:
|
42 |
+
print('About to download the pretrained weights file from {}'.format(weights_download_link))
|
43 |
+
if already_exists == False:
|
44 |
+
print('The size of the file is roughly 85MB. Continue? [y/n]')
|
45 |
+
else:
|
46 |
+
os.unlink(weights_path)
|
47 |
+
|
48 |
+
if already_exists or prompt():
|
49 |
+
print('Downloading...')
|
50 |
+
|
51 |
+
#urllib.urlretrieve(weights_download_link, weights_path)
|
52 |
+
#with open(weights_path,'wb') as f:
|
53 |
+
# f.write(requests.get(weights_download_link).content)
|
54 |
+
|
55 |
+
# downloading using wget due to issues with urlretrieve and requests
|
56 |
+
sys_call = 'wget {} -O {}'.format(weights_download_link, os.path.abspath(weights_path))
|
57 |
+
print("Running system call: {}".format(sys_call))
|
58 |
+
call(sys_call, shell=True)
|
59 |
+
|
60 |
+
if os.path.getsize(weights_path) / MB_FACTOR < 80:
|
61 |
+
raise ValueError("Download finished, but the resulting file is too small! " +
|
62 |
+
"It\'s only {} bytes.".format(os.path.getsize(weights_path)))
|
63 |
+
print('Downloaded weights to {}'.format(weights_path))
|
64 |
+
else:
|
65 |
+
print('Exiting.')
|
scripts/finetune_dataset.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Finetuning example.
|
2 |
+
"""
|
3 |
+
from __future__ import print_function
|
4 |
+
import sys
|
5 |
+
import numpy as np
|
6 |
+
from os.path import abspath, dirname
|
7 |
+
sys.path.insert(0, dirname(dirname(abspath(__file__))))
|
8 |
+
|
9 |
+
import json
|
10 |
+
import math
|
11 |
+
from torchmoji.model_def import torchmoji_transfer
|
12 |
+
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH
|
13 |
+
from torchmoji.finetuning import (
|
14 |
+
load_benchmark,
|
15 |
+
finetune)
|
16 |
+
from torchmoji.class_avg_finetuning import class_avg_finetune
|
17 |
+
|
18 |
+
def roundup(x):
|
19 |
+
return int(math.ceil(x / 10.0)) * 10
|
20 |
+
|
21 |
+
|
22 |
+
# Format: (dataset_name,
|
23 |
+
# path_to_dataset,
|
24 |
+
# nb_classes,
|
25 |
+
# use_f1_score)
|
26 |
+
DATASETS = [
|
27 |
+
#('SE0714', '../data/SE0714/raw.pickle', 3, True),
|
28 |
+
#('Olympic', '../data/Olympic/raw.pickle', 4, True),
|
29 |
+
#('PsychExp', '../data/PsychExp/raw.pickle', 7, True),
|
30 |
+
#('SS-Twitter', '../data/SS-Twitter/raw.pickle', 2, False),
|
31 |
+
('SS-Youtube', '../data/SS-Youtube/raw.pickle', 2, False),
|
32 |
+
#('SE1604', '../data/SE1604/raw.pickle', 3, False), # Excluded due to Twitter's ToS
|
33 |
+
#('SCv1', '../data/SCv1/raw.pickle', 2, True),
|
34 |
+
#('SCv2-GEN', '../data/SCv2-GEN/raw.pickle', 2, True)
|
35 |
+
]
|
36 |
+
|
37 |
+
RESULTS_DIR = 'results'
|
38 |
+
|
39 |
+
# 'new' | 'last' | 'full' | 'chain-thaw'
|
40 |
+
FINETUNE_METHOD = 'last'
|
41 |
+
VERBOSE = 1
|
42 |
+
|
43 |
+
nb_tokens = 50000
|
44 |
+
nb_epochs = 1000
|
45 |
+
epoch_size = 1000
|
46 |
+
|
47 |
+
with open(VOCAB_PATH, 'r') as f:
|
48 |
+
vocab = json.load(f)
|
49 |
+
|
50 |
+
for rerun_iter in range(5):
|
51 |
+
for p in DATASETS:
|
52 |
+
|
53 |
+
# debugging
|
54 |
+
assert len(vocab) == nb_tokens
|
55 |
+
|
56 |
+
dset = p[0]
|
57 |
+
path = p[1]
|
58 |
+
nb_classes = p[2]
|
59 |
+
use_f1_score = p[3]
|
60 |
+
|
61 |
+
if FINETUNE_METHOD == 'last':
|
62 |
+
extend_with = 0
|
63 |
+
elif FINETUNE_METHOD in ['new', 'full', 'chain-thaw']:
|
64 |
+
extend_with = 10000
|
65 |
+
else:
|
66 |
+
raise ValueError('Finetuning method not recognised!')
|
67 |
+
|
68 |
+
# Load dataset.
|
69 |
+
data = load_benchmark(path, vocab, extend_with=extend_with)
|
70 |
+
|
71 |
+
(X_train, y_train) = (data['texts'][0], data['labels'][0])
|
72 |
+
(X_val, y_val) = (data['texts'][1], data['labels'][1])
|
73 |
+
(X_test, y_test) = (data['texts'][2], data['labels'][2])
|
74 |
+
|
75 |
+
weight_path = PRETRAINED_PATH if FINETUNE_METHOD != 'new' else None
|
76 |
+
nb_model_classes = 2 if use_f1_score else nb_classes
|
77 |
+
model = torchmoji_transfer(
|
78 |
+
nb_model_classes,
|
79 |
+
weight_path,
|
80 |
+
extend_embedding=data['added'])
|
81 |
+
print(model)
|
82 |
+
|
83 |
+
# Training
|
84 |
+
print('Training: {}'.format(path))
|
85 |
+
if use_f1_score:
|
86 |
+
model, result = class_avg_finetune(model, data['texts'],
|
87 |
+
data['labels'],
|
88 |
+
nb_classes, data['batch_size'],
|
89 |
+
FINETUNE_METHOD,
|
90 |
+
verbose=VERBOSE)
|
91 |
+
else:
|
92 |
+
model, result = finetune(model, data['texts'], data['labels'],
|
93 |
+
nb_classes, data['batch_size'],
|
94 |
+
FINETUNE_METHOD, metric='acc',
|
95 |
+
verbose=VERBOSE)
|
96 |
+
|
97 |
+
# Write results
|
98 |
+
if use_f1_score:
|
99 |
+
print('Overall F1 score (dset = {}): {}'.format(dset, result))
|
100 |
+
with open('{}/{}_{}_{}_results.txt'.
|
101 |
+
format(RESULTS_DIR, dset, FINETUNE_METHOD, rerun_iter),
|
102 |
+
"w") as f:
|
103 |
+
f.write("F1: {}\n".format(result))
|
104 |
+
else:
|
105 |
+
print('Test accuracy (dset = {}): {}'.format(dset, result))
|
106 |
+
with open('{}/{}_{}_{}_results.txt'.
|
107 |
+
format(RESULTS_DIR, dset, FINETUNE_METHOD, rerun_iter),
|
108 |
+
"w") as f:
|
109 |
+
f.write("Acc: {}\n".format(result))
|
scripts/results/.gitkeep
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
setup.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup
|
2 |
+
|
3 |
+
setup(
|
4 |
+
name='torchmoji',
|
5 |
+
version='1.0',
|
6 |
+
packages=['torchmoji'],
|
7 |
+
description='torchMoji',
|
8 |
+
include_package_data=True,
|
9 |
+
install_requires=[
|
10 |
+
'emoji==0.4.5',
|
11 |
+
'numpy==1.13.1',
|
12 |
+
'scipy==0.19.1',
|
13 |
+
'scikit-learn==0.19.0',
|
14 |
+
'text-unidecode==1.0',
|
15 |
+
],
|
16 |
+
)
|
tests/test_finetuning.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, print_function, division, unicode_literals
|
2 |
+
|
3 |
+
import test_helper
|
4 |
+
|
5 |
+
from nose.plugins.attrib import attr
|
6 |
+
import json
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from torchmoji.class_avg_finetuning import relabel
|
10 |
+
from torchmoji.sentence_tokenizer import SentenceTokenizer
|
11 |
+
|
12 |
+
from torchmoji.finetuning import (
|
13 |
+
calculate_batchsize_maxlen,
|
14 |
+
freeze_layers,
|
15 |
+
change_trainable,
|
16 |
+
finetune,
|
17 |
+
load_benchmark
|
18 |
+
)
|
19 |
+
from torchmoji.model_def import (
|
20 |
+
torchmoji_transfer,
|
21 |
+
torchmoji_feature_encoding,
|
22 |
+
torchmoji_emojis
|
23 |
+
)
|
24 |
+
from torchmoji.global_variables import (
|
25 |
+
PRETRAINED_PATH,
|
26 |
+
NB_TOKENS,
|
27 |
+
VOCAB_PATH,
|
28 |
+
ROOT_PATH
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
def test_calculate_batchsize_maxlen():
|
33 |
+
""" Batch size and max length are calculated properly.
|
34 |
+
"""
|
35 |
+
texts = ['a b c d',
|
36 |
+
'e f g h i']
|
37 |
+
batch_size, maxlen = calculate_batchsize_maxlen(texts)
|
38 |
+
|
39 |
+
assert batch_size == 250
|
40 |
+
assert maxlen == 10, maxlen
|
41 |
+
|
42 |
+
|
43 |
+
def test_freeze_layers():
|
44 |
+
""" Correct layers are frozen.
|
45 |
+
"""
|
46 |
+
model = torchmoji_transfer(5)
|
47 |
+
keyword = 'output_layer'
|
48 |
+
|
49 |
+
model = freeze_layers(model, unfrozen_keyword=keyword)
|
50 |
+
|
51 |
+
for name, module in model.named_children():
|
52 |
+
trainable = keyword.lower() in name.lower()
|
53 |
+
assert all(p.requires_grad == trainable for p in module.parameters())
|
54 |
+
|
55 |
+
|
56 |
+
def test_change_trainable():
|
57 |
+
""" change_trainable() changes trainability of layers.
|
58 |
+
"""
|
59 |
+
model = torchmoji_transfer(5)
|
60 |
+
change_trainable(model.embed, False)
|
61 |
+
assert not any(p.requires_grad for p in model.embed.parameters())
|
62 |
+
change_trainable(model.embed, True)
|
63 |
+
assert all(p.requires_grad for p in model.embed.parameters())
|
64 |
+
|
65 |
+
|
66 |
+
def test_torchmoji_transfer_extend_embedding():
|
67 |
+
""" Defining torchmoji with extension.
|
68 |
+
"""
|
69 |
+
extend_with = 50
|
70 |
+
model = torchmoji_transfer(5, weight_path=PRETRAINED_PATH,
|
71 |
+
extend_embedding=extend_with)
|
72 |
+
embedding_layer = model.embed
|
73 |
+
assert embedding_layer.weight.size()[0] == NB_TOKENS + extend_with
|
74 |
+
|
75 |
+
|
76 |
+
def test_torchmoji_return_attention():
|
77 |
+
seq_tensor = np.array([[1]])
|
78 |
+
# test the output of the normal model
|
79 |
+
model = torchmoji_emojis(weight_path=PRETRAINED_PATH)
|
80 |
+
# check correct number of outputs
|
81 |
+
assert len(model(seq_tensor)) == 1
|
82 |
+
# repeat above described tests when returning attention weights
|
83 |
+
model = torchmoji_emojis(weight_path=PRETRAINED_PATH, return_attention=True)
|
84 |
+
assert len(model(seq_tensor)) == 2
|
85 |
+
|
86 |
+
|
87 |
+
def test_relabel():
|
88 |
+
""" relabel() works with multi-class labels.
|
89 |
+
"""
|
90 |
+
nb_classes = 3
|
91 |
+
inputs = np.array([
|
92 |
+
[True, False, False],
|
93 |
+
[False, True, False],
|
94 |
+
[True, False, True],
|
95 |
+
])
|
96 |
+
expected_0 = np.array([True, False, True])
|
97 |
+
expected_1 = np.array([False, True, False])
|
98 |
+
expected_2 = np.array([False, False, True])
|
99 |
+
|
100 |
+
assert np.array_equal(relabel(inputs, 0, nb_classes), expected_0)
|
101 |
+
assert np.array_equal(relabel(inputs, 1, nb_classes), expected_1)
|
102 |
+
assert np.array_equal(relabel(inputs, 2, nb_classes), expected_2)
|
103 |
+
|
104 |
+
|
105 |
+
def test_relabel_binary():
|
106 |
+
""" relabel() works with binary classification (no changes to labels)
|
107 |
+
"""
|
108 |
+
nb_classes = 2
|
109 |
+
inputs = np.array([True, False, False])
|
110 |
+
|
111 |
+
assert np.array_equal(relabel(inputs, 0, nb_classes), inputs)
|
112 |
+
|
113 |
+
|
114 |
+
@attr('slow')
|
115 |
+
def test_finetune_full():
|
116 |
+
""" finetuning using 'full'.
|
117 |
+
"""
|
118 |
+
DATASET_PATH = ROOT_PATH+'/data/SS-Youtube/raw.pickle'
|
119 |
+
nb_classes = 2
|
120 |
+
# Keras and pyTorch implementation of the Adam optimizer are slightly different and change a bit the results
|
121 |
+
# We reduce the min accuracy needed here to pass the test
|
122 |
+
# See e.g. https://discuss.pytorch.org/t/suboptimal-convergence-when-compared-with-tensorflow-model/5099/11
|
123 |
+
min_acc = 0.68
|
124 |
+
|
125 |
+
with open(VOCAB_PATH, 'r') as f:
|
126 |
+
vocab = json.load(f)
|
127 |
+
|
128 |
+
data = load_benchmark(DATASET_PATH, vocab, extend_with=10000)
|
129 |
+
print('Loading pyTorch model from {}.'.format(PRETRAINED_PATH))
|
130 |
+
model = torchmoji_transfer(nb_classes, PRETRAINED_PATH, extend_embedding=data['added'])
|
131 |
+
print(model)
|
132 |
+
model, acc = finetune(model, data['texts'], data['labels'], nb_classes,
|
133 |
+
data['batch_size'], method='full', nb_epochs=1)
|
134 |
+
|
135 |
+
print("Finetune full SS-Youtube 1 epoch acc: {}".format(acc))
|
136 |
+
assert acc >= min_acc
|
137 |
+
|
138 |
+
|
139 |
+
@attr('slow')
|
140 |
+
def test_finetune_last():
|
141 |
+
""" finetuning using 'last'.
|
142 |
+
"""
|
143 |
+
dataset_path = ROOT_PATH + '/data/SS-Youtube/raw.pickle'
|
144 |
+
nb_classes = 2
|
145 |
+
min_acc = 0.68
|
146 |
+
|
147 |
+
with open(VOCAB_PATH, 'r') as f:
|
148 |
+
vocab = json.load(f)
|
149 |
+
|
150 |
+
data = load_benchmark(dataset_path, vocab)
|
151 |
+
print('Loading model from {}.'.format(PRETRAINED_PATH))
|
152 |
+
model = torchmoji_transfer(nb_classes, PRETRAINED_PATH)
|
153 |
+
print(model)
|
154 |
+
model, acc = finetune(model, data['texts'], data['labels'], nb_classes,
|
155 |
+
data['batch_size'], method='last', nb_epochs=1)
|
156 |
+
|
157 |
+
print("Finetune last SS-Youtube 1 epoch acc: {}".format(acc))
|
158 |
+
|
159 |
+
assert acc >= min_acc
|
160 |
+
|
161 |
+
|
162 |
+
def test_score_emoji():
|
163 |
+
""" Emoji predictions make sense.
|
164 |
+
"""
|
165 |
+
test_sentences = [
|
166 |
+
'I love mom\'s cooking',
|
167 |
+
'I love how you never reply back..',
|
168 |
+
'I love cruising with my homies',
|
169 |
+
'I love messing with yo mind!!',
|
170 |
+
'I love you and now you\'re just gone..',
|
171 |
+
'This is shit',
|
172 |
+
'This is the shit'
|
173 |
+
]
|
174 |
+
|
175 |
+
expected = [
|
176 |
+
np.array([36, 4, 8, 16, 47]),
|
177 |
+
np.array([1, 19, 55, 25, 46]),
|
178 |
+
np.array([31, 6, 30, 15, 13]),
|
179 |
+
np.array([54, 44, 9, 50, 49]),
|
180 |
+
np.array([46, 5, 27, 35, 34]),
|
181 |
+
np.array([55, 32, 27, 1, 37]),
|
182 |
+
np.array([48, 11, 6, 31, 9])
|
183 |
+
]
|
184 |
+
|
185 |
+
def top_elements(array, k):
|
186 |
+
ind = np.argpartition(array, -k)[-k:]
|
187 |
+
return ind[np.argsort(array[ind])][::-1]
|
188 |
+
|
189 |
+
# Initialize by loading dictionary and tokenize texts
|
190 |
+
with open(VOCAB_PATH, 'r') as f:
|
191 |
+
vocabulary = json.load(f)
|
192 |
+
|
193 |
+
st = SentenceTokenizer(vocabulary, 30)
|
194 |
+
tokens, _, _ = st.tokenize_sentences(test_sentences)
|
195 |
+
|
196 |
+
# Load model and run
|
197 |
+
model = torchmoji_emojis(weight_path=PRETRAINED_PATH)
|
198 |
+
prob = model(tokens)
|
199 |
+
|
200 |
+
# Find top emojis for each sentence
|
201 |
+
for i, t_prob in enumerate(list(prob)):
|
202 |
+
assert np.array_equal(top_elements(t_prob, 5), expected[i])
|
203 |
+
|
204 |
+
|
205 |
+
def test_encode_texts():
|
206 |
+
""" Text encoding is stable.
|
207 |
+
"""
|
208 |
+
|
209 |
+
TEST_SENTENCES = ['I love mom\'s cooking',
|
210 |
+
'I love how you never reply back..',
|
211 |
+
'I love cruising with my homies',
|
212 |
+
'I love messing with yo mind!!',
|
213 |
+
'I love you and now you\'re just gone..',
|
214 |
+
'This is shit',
|
215 |
+
'This is the shit']
|
216 |
+
|
217 |
+
|
218 |
+
maxlen = 30
|
219 |
+
batch_size = 32
|
220 |
+
|
221 |
+
with open(VOCAB_PATH, 'r') as f:
|
222 |
+
vocabulary = json.load(f)
|
223 |
+
|
224 |
+
st = SentenceTokenizer(vocabulary, maxlen)
|
225 |
+
|
226 |
+
print('Loading model from {}.'.format(PRETRAINED_PATH))
|
227 |
+
model = torchmoji_feature_encoding(PRETRAINED_PATH)
|
228 |
+
print(model)
|
229 |
+
tokenized, _, _ = st.tokenize_sentences(TEST_SENTENCES)
|
230 |
+
encoding = model(tokenized)
|
231 |
+
|
232 |
+
avg_across_sentences = np.around(np.mean(encoding, axis=0)[:5], 3)
|
233 |
+
assert np.allclose(avg_across_sentences, np.array([-0.023, 0.021, -0.037, -0.001, -0.005]))
|
234 |
+
|
235 |
+
test_encode_texts()
|
tests/test_helper.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Module import helper.
|
2 |
+
Modifies PATH in order to allow us to import the torchmoji directory.
|
3 |
+
"""
|
4 |
+
import sys
|
5 |
+
from os.path import abspath, dirname
|
6 |
+
sys.path.insert(0, dirname(dirname(abspath(__file__))))
|
tests/test_sentence_tokenizer.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, print_function, division, unicode_literals
|
2 |
+
import test_helper
|
3 |
+
import json
|
4 |
+
|
5 |
+
from torchmoji.sentence_tokenizer import SentenceTokenizer
|
6 |
+
|
7 |
+
sentences = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J']
|
8 |
+
|
9 |
+
dicts = [
|
10 |
+
{'label': 0},
|
11 |
+
{'label': 1},
|
12 |
+
{'label': 2},
|
13 |
+
{'label': 3},
|
14 |
+
{'label': 4},
|
15 |
+
{'label': 5},
|
16 |
+
{'label': 6},
|
17 |
+
{'label': 7},
|
18 |
+
{'label': 8},
|
19 |
+
{'label': 9},
|
20 |
+
]
|
21 |
+
|
22 |
+
train_ind = [0, 5, 3, 6, 8]
|
23 |
+
val_ind = [9, 2, 1]
|
24 |
+
test_ind = [4, 7]
|
25 |
+
|
26 |
+
with open('../model/vocabulary.json', 'r') as f:
|
27 |
+
vocab = json.load(f)
|
28 |
+
|
29 |
+
def test_dataset_split_parameter():
|
30 |
+
""" Dataset is split in the desired ratios
|
31 |
+
"""
|
32 |
+
split_parameter = [0.7, 0.1, 0.2]
|
33 |
+
st = SentenceTokenizer(vocab, 30)
|
34 |
+
|
35 |
+
result, result_dicts, _ = st.split_train_val_test(sentences, dicts,
|
36 |
+
split_parameter, extend_with=0)
|
37 |
+
train = result[0]
|
38 |
+
val = result[1]
|
39 |
+
test = result[2]
|
40 |
+
|
41 |
+
train_dicts = result_dicts[0]
|
42 |
+
val_dicts = result_dicts[1]
|
43 |
+
test_dicts = result_dicts[2]
|
44 |
+
|
45 |
+
assert len(train) == len(sentences) * split_parameter[0]
|
46 |
+
assert len(val) == len(sentences) * split_parameter[1]
|
47 |
+
assert len(test) == len(sentences) * split_parameter[2]
|
48 |
+
|
49 |
+
assert len(train_dicts) == len(dicts) * split_parameter[0]
|
50 |
+
assert len(val_dicts) == len(dicts) * split_parameter[1]
|
51 |
+
assert len(test_dicts) == len(dicts) * split_parameter[2]
|
52 |
+
|
53 |
+
def test_dataset_split_explicit():
|
54 |
+
""" Dataset is split according to given indices
|
55 |
+
"""
|
56 |
+
split_parameter = [train_ind, val_ind, test_ind]
|
57 |
+
st = SentenceTokenizer(vocab, 30)
|
58 |
+
tokenized, _, _ = st.tokenize_sentences(sentences)
|
59 |
+
|
60 |
+
result, result_dicts, added = st.split_train_val_test(sentences, dicts, split_parameter, extend_with=0)
|
61 |
+
train = result[0]
|
62 |
+
val = result[1]
|
63 |
+
test = result[2]
|
64 |
+
|
65 |
+
train_dicts = result_dicts[0]
|
66 |
+
val_dicts = result_dicts[1]
|
67 |
+
test_dicts = result_dicts[2]
|
68 |
+
|
69 |
+
tokenized = tokenized
|
70 |
+
|
71 |
+
for i, sentence in enumerate(sentences):
|
72 |
+
if i in train_ind:
|
73 |
+
assert tokenized[i] in train
|
74 |
+
assert dicts[i] in train_dicts
|
75 |
+
elif i in val_ind:
|
76 |
+
assert tokenized[i] in val
|
77 |
+
assert dicts[i] in val_dicts
|
78 |
+
elif i in test_ind:
|
79 |
+
assert tokenized[i] in test
|
80 |
+
assert dicts[i] in test_dicts
|
81 |
+
|
82 |
+
assert len(train) == len(train_ind)
|
83 |
+
assert len(val) == len(val_ind)
|
84 |
+
assert len(test) == len(test_ind)
|
85 |
+
assert len(train_dicts) == len(train_ind)
|
86 |
+
assert len(val_dicts) == len(val_ind)
|
87 |
+
assert len(test_dicts) == len(test_ind)
|
88 |
+
|
89 |
+
def test_id_to_sentence():
|
90 |
+
"""Tokenizing and converting back preserves the input.
|
91 |
+
"""
|
92 |
+
vb = {'CUSTOM_MASK': 0,
|
93 |
+
'aasdf': 1000,
|
94 |
+
'basdf': 2000}
|
95 |
+
|
96 |
+
sentence = 'aasdf basdf basdf basdf'
|
97 |
+
st = SentenceTokenizer(vb, 30)
|
98 |
+
token, _, _ = st.tokenize_sentences([sentence])
|
99 |
+
assert st.to_sentence(token[0]) == sentence
|
100 |
+
|
101 |
+
def test_id_to_sentence_with_unknown():
|
102 |
+
"""Tokenizing and converting back preserves the input, except for unknowns.
|
103 |
+
"""
|
104 |
+
vb = {'CUSTOM_MASK': 0,
|
105 |
+
'CUSTOM_UNKNOWN': 1,
|
106 |
+
'aasdf': 1000,
|
107 |
+
'basdf': 2000}
|
108 |
+
|
109 |
+
sentence = 'aasdf basdf ccc'
|
110 |
+
expected = 'aasdf basdf CUSTOM_UNKNOWN'
|
111 |
+
st = SentenceTokenizer(vb, 30)
|
112 |
+
token, _, _ = st.tokenize_sentences([sentence])
|
113 |
+
assert st.to_sentence(token[0]) == expected
|
tests/test_tokenizer.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
""" Tokenization tests.
|
3 |
+
"""
|
4 |
+
from __future__ import absolute_import, print_function, division, unicode_literals
|
5 |
+
|
6 |
+
import sys
|
7 |
+
from nose.tools import nottest
|
8 |
+
from os.path import dirname, abspath
|
9 |
+
sys.path.append(dirname(dirname(abspath(__file__))))
|
10 |
+
from torchmoji.tokenizer import tokenize
|
11 |
+
|
12 |
+
TESTS_NORMAL = [
|
13 |
+
('200K words!', ['200', 'K', 'words', '!']),
|
14 |
+
]
|
15 |
+
|
16 |
+
TESTS_EMOJIS = [
|
17 |
+
('i \U0001f496 you to the moon and back',
|
18 |
+
['i', '\U0001f496', 'you', 'to', 'the', 'moon', 'and', 'back']),
|
19 |
+
("i\U0001f496you to the \u2605's and back",
|
20 |
+
['i', '\U0001f496', 'you', 'to', 'the',
|
21 |
+
'\u2605', "'", 's', 'and', 'back']),
|
22 |
+
('~<3~', ['~', '<3', '~']),
|
23 |
+
('<333', ['<333']),
|
24 |
+
(':-)', [':-)']),
|
25 |
+
('>:-(', ['>:-(']),
|
26 |
+
('\u266b\u266a\u2605\u2606\u2665\u2764\u2661',
|
27 |
+
['\u266b', '\u266a', '\u2605', '\u2606',
|
28 |
+
'\u2665', '\u2764', '\u2661']),
|
29 |
+
]
|
30 |
+
|
31 |
+
TESTS_URLS = [
|
32 |
+
('www.sample.com', ['www.sample.com']),
|
33 |
+
('http://endless.horse', ['http://endless.horse']),
|
34 |
+
('https://github.mit.ed', ['https://github.mit.ed']),
|
35 |
+
]
|
36 |
+
|
37 |
+
TESTS_TWITTER = [
|
38 |
+
('#blacklivesmatter', ['#blacklivesmatter']),
|
39 |
+
('#99_percent.', ['#99_percent', '.']),
|
40 |
+
('the#99%', ['the', '#99', '%']),
|
41 |
+
('@golden_zenith', ['@golden_zenith']),
|
42 |
+
('@99_percent', ['@99_percent']),
|
43 |
+
('[email protected]', ['[email protected]']),
|
44 |
+
]
|
45 |
+
|
46 |
+
TESTS_PHONE_NUMS = [
|
47 |
+
('518)528-0252', ['518', ')', '528', '-', '0252']),
|
48 |
+
('1200-0221-0234', ['1200', '-', '0221', '-', '0234']),
|
49 |
+
('1200.0221.0234', ['1200', '.', '0221', '.', '0234']),
|
50 |
+
]
|
51 |
+
|
52 |
+
TESTS_DATETIME = [
|
53 |
+
('15:00', ['15', ':', '00']),
|
54 |
+
('2:00pm', ['2', ':', '00', 'pm']),
|
55 |
+
('9/14/16', ['9', '/', '14', '/', '16']),
|
56 |
+
]
|
57 |
+
|
58 |
+
TESTS_CURRENCIES = [
|
59 |
+
('517.933\xa3', ['517', '.', '933', '\xa3']),
|
60 |
+
('$517.87', ['$', '517', '.', '87']),
|
61 |
+
('1201.6598', ['1201', '.', '6598']),
|
62 |
+
('120,6', ['120', ',', '6']),
|
63 |
+
('10,00\u20ac', ['10', ',', '00', '\u20ac']),
|
64 |
+
('1,000', ['1', ',', '000']),
|
65 |
+
('1200pesos', ['1200', 'pesos']),
|
66 |
+
]
|
67 |
+
|
68 |
+
TESTS_NUM_SYM = [
|
69 |
+
('5162f', ['5162', 'f']),
|
70 |
+
('f5162', ['f', '5162']),
|
71 |
+
('1203(', ['1203', '(']),
|
72 |
+
('(1203)', ['(', '1203', ')']),
|
73 |
+
('1200/', ['1200', '/']),
|
74 |
+
('1200+', ['1200', '+']),
|
75 |
+
('1202o-east', ['1202', 'o-east']),
|
76 |
+
('1200r', ['1200', 'r']),
|
77 |
+
('1200-1400', ['1200', '-', '1400']),
|
78 |
+
('120/today', ['120', '/', 'today']),
|
79 |
+
('today/120', ['today', '/', '120']),
|
80 |
+
('120/5', ['120', '/', '5']),
|
81 |
+
("120'/5", ['120', "'", '/', '5']),
|
82 |
+
('120/5pro', ['120', '/', '5', 'pro']),
|
83 |
+
("1200's,)", ['1200', "'", 's', ',', ')']),
|
84 |
+
('120.76.218.207', ['120', '.', '76', '.', '218', '.', '207']),
|
85 |
+
]
|
86 |
+
|
87 |
+
TESTS_PUNCTUATION = [
|
88 |
+
("don''t", ['don', "''", 't']),
|
89 |
+
("don'tcha", ["don'tcha"]),
|
90 |
+
('no?!?!;', ['no', '?', '!', '?', '!', ';']),
|
91 |
+
('no??!!..', ['no', '??', '!!', '..']),
|
92 |
+
('a.m.', ['a.m.']),
|
93 |
+
('.s.u', ['.', 's', '.', 'u']),
|
94 |
+
('!!i..n__', ['!!', 'i', '..', 'n', '__']),
|
95 |
+
('lv(<3)w(3>)u Mr.!', ['lv', '(', '<3', ')', 'w', '(', '3',
|
96 |
+
'>', ')', 'u', 'Mr.', '!']),
|
97 |
+
('-->', ['--', '>']),
|
98 |
+
('->', ['-', '>']),
|
99 |
+
('<-', ['<', '-']),
|
100 |
+
('<--', ['<', '--']),
|
101 |
+
('hello (@person)', ['hello', '(', '@person', ')']),
|
102 |
+
]
|
103 |
+
|
104 |
+
|
105 |
+
def test_normal():
|
106 |
+
""" Normal/combined usage.
|
107 |
+
"""
|
108 |
+
test_base(TESTS_NORMAL)
|
109 |
+
|
110 |
+
|
111 |
+
def test_emojis():
|
112 |
+
""" Tokenizing emojis/emoticons/decorations.
|
113 |
+
"""
|
114 |
+
test_base(TESTS_EMOJIS)
|
115 |
+
|
116 |
+
|
117 |
+
def test_urls():
|
118 |
+
""" Tokenizing URLs.
|
119 |
+
"""
|
120 |
+
test_base(TESTS_URLS)
|
121 |
+
|
122 |
+
|
123 |
+
def test_twitter():
|
124 |
+
""" Tokenizing hashtags, mentions and emails.
|
125 |
+
"""
|
126 |
+
test_base(TESTS_TWITTER)
|
127 |
+
|
128 |
+
|
129 |
+
def test_phone_nums():
|
130 |
+
""" Tokenizing phone numbers.
|
131 |
+
"""
|
132 |
+
test_base(TESTS_PHONE_NUMS)
|
133 |
+
|
134 |
+
|
135 |
+
def test_datetime():
|
136 |
+
""" Tokenizing dates and times.
|
137 |
+
"""
|
138 |
+
test_base(TESTS_DATETIME)
|
139 |
+
|
140 |
+
|
141 |
+
def test_currencies():
|
142 |
+
""" Tokenizing currencies.
|
143 |
+
"""
|
144 |
+
test_base(TESTS_CURRENCIES)
|
145 |
+
|
146 |
+
|
147 |
+
def test_num_sym():
|
148 |
+
""" Tokenizing combinations of numbers and symbols.
|
149 |
+
"""
|
150 |
+
test_base(TESTS_NUM_SYM)
|
151 |
+
|
152 |
+
|
153 |
+
def test_punctuation():
|
154 |
+
""" Tokenizing punctuation and contractions.
|
155 |
+
"""
|
156 |
+
test_base(TESTS_PUNCTUATION)
|
157 |
+
|
158 |
+
|
159 |
+
@nottest
|
160 |
+
def test_base(tests):
|
161 |
+
""" Base function for running tests.
|
162 |
+
"""
|
163 |
+
for (test, expected) in tests:
|
164 |
+
actual = tokenize(test)
|
165 |
+
assert actual == expected, \
|
166 |
+
"Tokenization of \'{}\' failed, expected: {}, actual: {}"\
|
167 |
+
.format(test, expected, actual)
|
tests/test_word_generator.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import sys
|
3 |
+
from os.path import dirname, abspath
|
4 |
+
sys.path.append(dirname(dirname(abspath(__file__))))
|
5 |
+
from nose.tools import raises
|
6 |
+
from torchmoji.word_generator import WordGenerator
|
7 |
+
|
8 |
+
IS_PYTHON2 = int(sys.version[0]) == 2
|
9 |
+
|
10 |
+
@raises(ValueError)
|
11 |
+
def test_only_unicode_accepted():
|
12 |
+
""" Non-Unicode strings raise a ValueError.
|
13 |
+
In Python 3 all string are Unicode
|
14 |
+
"""
|
15 |
+
if not IS_PYTHON2:
|
16 |
+
raise ValueError("You are using python 3 so this test should always pass")
|
17 |
+
|
18 |
+
sentences = [
|
19 |
+
u'Hello world',
|
20 |
+
u'I am unicode',
|
21 |
+
'I am not unicode',
|
22 |
+
]
|
23 |
+
|
24 |
+
wg = WordGenerator(sentences)
|
25 |
+
for w in wg:
|
26 |
+
pass
|
27 |
+
|
28 |
+
|
29 |
+
def test_unicode_sentences_ignored_if_set():
|
30 |
+
""" Strings with Unicode characters tokenize to empty array if they're not allowed.
|
31 |
+
"""
|
32 |
+
sentence = [u'Dobrý den, jak se máš?']
|
33 |
+
wg = WordGenerator(sentence, allow_unicode_text=False)
|
34 |
+
assert wg.get_words(sentence[0]) == []
|
35 |
+
|
36 |
+
|
37 |
+
def test_check_ascii():
|
38 |
+
""" check_ascii recognises ASCII words properly.
|
39 |
+
In Python 3 all string are Unicode
|
40 |
+
"""
|
41 |
+
if not IS_PYTHON2:
|
42 |
+
return
|
43 |
+
|
44 |
+
wg = WordGenerator([])
|
45 |
+
assert wg.check_ascii('ASCII')
|
46 |
+
assert not wg.check_ascii('ščřžýá')
|
47 |
+
assert not wg.check_ascii('❤ ☀ ☆ ☂ ☻ ♞ ☯ ☭ ☢')
|
48 |
+
|
49 |
+
|
50 |
+
def test_convert_unicode_word():
|
51 |
+
""" convert_unicode_word converts Unicode words correctly.
|
52 |
+
"""
|
53 |
+
wg = WordGenerator([], allow_unicode_text=True)
|
54 |
+
|
55 |
+
result = wg.convert_unicode_word(u'č')
|
56 |
+
assert result == (True, u'\u010d'), '{}'.format(result)
|
57 |
+
|
58 |
+
|
59 |
+
def test_convert_unicode_word_ignores_if_set():
|
60 |
+
""" convert_unicode_word ignores Unicode words if set.
|
61 |
+
"""
|
62 |
+
wg = WordGenerator([], allow_unicode_text=False)
|
63 |
+
|
64 |
+
result = wg.convert_unicode_word(u'č')
|
65 |
+
assert result == (False, ''), '{}'.format(result)
|
66 |
+
|
67 |
+
|
68 |
+
def test_convert_unicode_chars():
|
69 |
+
""" convert_unicode_word correctly converts accented characters.
|
70 |
+
"""
|
71 |
+
wg = WordGenerator([], allow_unicode_text=True)
|
72 |
+
result = wg.convert_unicode_word(u'ěščřžýáíé')
|
73 |
+
assert result == (True, u'\u011b\u0161\u010d\u0159\u017e\xfd\xe1\xed\xe9'), '{}'.format(result)
|