import nltk
import spacy
from fastcoref import spacy_component  # imported for its side effect: registers the "fastcoref" pipe with spaCy

class TextPreprocessor:
  """
  Preprocesses text for the pipeline.

  Resolves coreferences, then splits each sentence into a positive and a
  negative prompt, ready for CLIP embedding generation downstream.
  """
  def __init__(self):
    # Download the resources needed at runtime: the NLTK sentence/word
    # tokenizer and the small English spaCy model.
    nltk.download('punkt')
    spacy.cli.download("en_core_web_sm")
    self.nlp = spacy.load("en_core_web_sm")
    # Attach the fastcoref coreference-resolution component to the pipeline.
    self.nlp.add_pipe(
      "fastcoref",
      config={'model_architecture': 'LingMessCoref', 'model_path': 'biu-nlp/lingmess-coref', 'device': 'cpu'}
    )
  
  def coref(self, text=None):
    '''
    Performs coreference resolution.

    Parameters:
    text: the input paragraph whose coreferences are to be resolved.
          Defaults to a short example sentence about Alice.

    Returns:
    The coreference-resolved paragraph.
    '''
    if not text:
      text = 'Alice goes down the rabbit hole. Where she would discover a new reality beyond her expectations.'
    doc = self.nlp(
      text,
      component_cfg={"fastcoref": {'resolve_text': True}}
    )
    # doc._.coref_clusters holds the detected coreference clusters.
    return doc._.resolved_text
  
  def neg_prompt(self, string: str):
    """
    Splits text into a positive and a negative prompt.

    Everything before a negation token ("not" / "n't") goes into the positive
    prompt and everything after it goes into the negative prompt; the word
    immediately preceding the negation (typically an auxiliary verb) is dropped.
    """
    positive = ""
    negative = ""
    words = nltk.word_tokenize(string)
    for i, word in enumerate(words[:-1]):
      if words[i + 1].lower() not in ["n't", 'not']:
        positive += " " + word
      else:
        # Negation found: the rest of the sentence becomes the negative prompt.
        for wor in words[i + 2:]:
          negative += " " + wor
        return {'pos': positive, 'neg': negative}
    if words:
      positive += " " + words[-1]
    return {'pos': positive, 'neg': negative}
  
  def __call__(self, text):
    """
    Runs the full preprocessing pipeline on a paragraph: coreference
    resolution followed by positive/negative prompt splitting per sentence.
    Returns the processed prompt dicts and the original sentence list.
    """
    old_sentences = nltk.sent_tokenize(text)
    coref_text = self.coref(text)
    sentences = nltk.sent_tokenize(coref_text)
    processed_sentences = []
    for sentence in sentences:
        processed_sentences.append(self.neg_prompt(sentence))
    return processed_sentences, old_sentences
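

# A minimal usage sketch (not part of the original module); the example texts
# below are illustrative assumptions, not fixtures from the project.
if __name__ == "__main__":
  preprocessor = TextPreprocessor()

  # Coreference resolution: "she"/"her" should resolve to "Alice".
  print(preprocessor.coref())

  # Prompt splitting on an explicit negation; the token just before "not"
  # ("is") is dropped by design.
  print(preprocessor.neg_prompt("a painting of a forest that is not photorealistic"))

  # Full pipeline: per-sentence prompt dicts plus the original sentences.
  prompts, original_sentences = preprocessor(
    "A dog runs on the beach. It is not wearing a collar."
  )
  print(prompts)
  print(original_sentences)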