langdonholmes committed on
Commit
5c59636
β€’
1 Parent(s): 3fde3db

refactored but still needs stress testing

Browse files
spacy_analyzer.py β†’ analyzer.py RENAMED
@@ -1,36 +1,31 @@
1
-
2
- from presidio_analyzer import (
3
- AnalyzerEngine,
4
- RecognizerResult,
5
- RecognizerRegistry,
6
- LocalRecognizer,
7
- AnalysisExplanation,
8
- )
9
-
10
- from presidio_analyzer.nlp_engine import NlpEngineProvider, NlpArtifacts
11
  from typing import Optional
12
 
13
- import logging
14
- logger = logging.getLogger("presidio-analyzer")
 
 
 
 
15
 
16
  class CustomSpacyRecognizer(LocalRecognizer):
17
  ENTITIES = [
18
- "STUDENT",
19
  ]
20
 
21
- DEFAULT_EXPLANATION = "Identified as {} by a Student Name Detection Model"
22
 
23
  CHECK_LABEL_GROUPS = [
24
- ({"STUDENT"}, {"STUDENT"}),
25
  ]
26
 
27
  MODEL_LANGUAGES = {
28
- "en": "langdonholmes/en_student_name_detector",
29
  }
30
 
31
  def __init__(
32
  self,
33
- supported_language: str = "en",
34
  supported_entities: Optional[list[str]] = None,
35
  check_label_groups: Optional[tuple[set, set]] = None,
36
  ner_strength: float = 0.85,
@@ -46,25 +41,25 @@ class CustomSpacyRecognizer(LocalRecognizer):
46
  )
47
 
48
  def load(self) -> None:
49
- """Load the model, not used. Model is loaded during initialization."""
50
  pass
51
 
52
  def get_supported_entities(self) -> list[str]:
53
- """
54
  Return supported entities by this model.
55
  :return: List of the supported entities.
56
- """
57
  return self.supported_entities
58
 
59
  def build_spacy_explanation(
60
  self, original_score: float, explanation: str
61
  ) -> AnalysisExplanation:
62
- """
63
  Create explanation for why this result was detected.
64
  :param original_score: Score given by this recognizer
65
  :param explanation: Explanation string
66
  :return:
67
- """
68
  explanation = AnalysisExplanation(
69
  recognizer=self.__class__.__name__,
70
  original_score=original_score,
@@ -76,15 +71,15 @@ class CustomSpacyRecognizer(LocalRecognizer):
76
  text: str,
77
  entities: list[str] = None,
78
  nlp_artifacts: NlpArtifacts = None):
79
- """Analyze input using Analyzer engine and input arguments (kwargs)."""
80
 
81
- if not entities or "All" in entities:
82
  entities = None
83
 
84
  results = []
85
 
86
  if not nlp_artifacts:
87
- logger.warning("Skipping SpaCy, nlp artifacts not provided...")
88
  return results
89
 
90
  ner_entities = nlp_artifacts.entities
@@ -123,7 +118,7 @@ class CustomSpacyRecognizer(LocalRecognizer):
123
  )
124
 
125
  def prepare_analyzer(configuration):
126
- """Handle Preparation of Analyzer Engine for Presidio."""
127
 
128
  spacy_recognizer = CustomSpacyRecognizer()
129
 
@@ -137,10 +132,10 @@ def prepare_analyzer(configuration):
137
  registry.add_recognizer(spacy_recognizer)
138
 
139
  # remove the nlp engine we passed, to use custom label mappings
140
- registry.remove_recognizer("SpacyRecognizer")
141
 
142
  analyzer = AnalyzerEngine(nlp_engine=nlp_engine,
143
  registry=registry,
144
- supported_languages=["en"])
145
 
146
  return analyzer
 
1
+ import logging
 
 
 
 
 
 
 
 
 
2
  from typing import Optional
3
 
4
+ from presidio_analyzer import (AnalysisExplanation, AnalyzerEngine,
5
+ LocalRecognizer, RecognizerRegistry,
6
+ RecognizerResult)
7
+ from presidio_analyzer.nlp_engine import NlpArtifacts, NlpEngineProvider
8
+
9
+ logger = logging.getLogger('presidio-analyzer')
10
 
11
  class CustomSpacyRecognizer(LocalRecognizer):
12
  ENTITIES = [
13
+ 'STUDENT',
14
  ]
15
 
16
+ DEFAULT_EXPLANATION = 'Identified as {} by a Student Name Detection Model'
17
 
18
  CHECK_LABEL_GROUPS = [
19
+ ({'STUDENT'}, {'STUDENT'}),
20
  ]
21
 
22
  MODEL_LANGUAGES = {
23
+ 'en': 'langdonholmes/en_student_name_detector',
24
  }
25
 
26
  def __init__(
27
  self,
28
+ supported_language: str = 'en',
29
  supported_entities: Optional[list[str]] = None,
30
  check_label_groups: Optional[tuple[set, set]] = None,
31
  ner_strength: float = 0.85,
 
41
  )
42
 
43
  def load(self) -> None:
44
+ '''Load the model, not used. Model is loaded during initialization.'''
45
  pass
46
 
47
  def get_supported_entities(self) -> list[str]:
48
+ '''
49
  Return supported entities by this model.
50
  :return: List of the supported entities.
51
+ '''
52
  return self.supported_entities
53
 
54
  def build_spacy_explanation(
55
  self, original_score: float, explanation: str
56
  ) -> AnalysisExplanation:
57
+ '''
58
  Create explanation for why this result was detected.
59
  :param original_score: Score given by this recognizer
60
  :param explanation: Explanation string
61
  :return:
62
+ '''
63
  explanation = AnalysisExplanation(
64
  recognizer=self.__class__.__name__,
65
  original_score=original_score,
 
71
  text: str,
72
  entities: list[str] = None,
73
  nlp_artifacts: NlpArtifacts = None):
74
+ '''Analyze input using Analyzer engine and input arguments (kwargs).'''
75
 
76
+ if not entities or 'All' in entities:
77
  entities = None
78
 
79
  results = []
80
 
81
  if not nlp_artifacts:
82
+ logger.warning('Skipping SpaCy, nlp artifacts not provided...')
83
  return results
84
 
85
  ner_entities = nlp_artifacts.entities
 
118
  )
119
 
120
  def prepare_analyzer(configuration):
121
+ '''Handle Preparation of Analyzer Engine for Presidio.'''
122
 
123
  spacy_recognizer = CustomSpacyRecognizer()
124
 
 
132
  registry.add_recognizer(spacy_recognizer)
133
 
134
  # remove the nlp engine we passed, to use custom label mappings
135
+ registry.remove_recognizer('SpacyRecognizer')
136
 
137
  analyzer = AnalyzerEngine(nlp_engine=nlp_engine,
138
  registry=registry,
139
+ supported_languages=['en'])
140
 
141
  return analyzer
anonymizer.py CHANGED
@@ -1,24 +1,55 @@
 
1
  from presidio_anonymizer import AnonymizerEngine
2
  from presidio_anonymizer.entities import OperatorConfig
3
- from presidio_analyzer import RecognizerResult
4
 
5
- def retrieve_name_records():
6
- """Read in a table of names with gender and country code fields."""
7
- pass
8
 
9
- def generate_surrogate(name):
10
- """Return appropriate surrogate name from text string"""
11
- if "John" in name:
12
- return "Jill"
13
- else:
14
- return "SURROGATE_NAME"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def anonymize(
17
  anonymizer: AnonymizerEngine,
18
  text: str,
19
  analyze_results: list[RecognizerResult]
20
  ):
21
- """Anonymize identified input using Presidio Anonymizer."""
22
 
23
  if not text:
24
  return
@@ -27,11 +58,18 @@ def anonymize(
27
  text,
28
  analyze_results,
29
  operators={
30
- "STUDENT": OperatorConfig("custom", {"lambda": generate_surrogate}),
31
- "EMAIL_ADDRESS": OperatorConfig("replace", {"new_value": "[email protected]"}),
32
- "PHONE_NUMBER": OperatorConfig("replace", {"new_value": "888-888-8888"}),
33
- "URL": OperatorConfig("replace", {"new_value": "aol.com"}),
 
 
 
 
34
  }
35
  )
36
 
37
- return res.text
 
 
 
 
1
+ from presidio_analyzer import RecognizerResult
2
  from presidio_anonymizer import AnonymizerEngine
3
  from presidio_anonymizer.entities import OperatorConfig
 
4
 
5
+ from names_database import NameDatabase
6
+
7
+ names_db = NameDatabase()
8
 
9
def split_name(original_name: str):
    '''Split a full name into (first, last) parts.

    Token-count heuristic:
      1 token   -> first name only
      2 tokens  -> first and last name
      3 tokens  -> one first name, two last names
      4 tokens  -> two first names, two last names
      otherwise -> (None, None)
    '''
    tokens = original_name.split()
    if len(tokens) == 1:
        return tokens[0], None
    if len(tokens) == 2:
        return tokens[0], tokens[1]
    if len(tokens) == 3:
        return tokens[0], ' '.join(tokens[1:])
    if len(tokens) == 4:
        return ' '.join(tokens[:2]), ' '.join(tokens[2:])
    return None, None
26
+
27
def generate_surrogate(original_name: str):
    '''Generate a surrogate (replacement) name for a detected student name.

    Estimates the likely gender from the first name(s) and the likely
    country from the last name(s), then samples replacement names with
    matching demographics from the names database.
    '''
    first_names, last_names = split_name(original_name)

    # Demographic hints are optional; None means "no constraint" downstream.
    gender = names_db.get_gender(first_names) if first_names else None
    country = names_db.get_country(last_names) if last_names else None

    # get_random_name samples two candidate rows: one supplies the
    # surrogate first name, the other the surrogate last name.
    name_candidates = names_db.get_random_name(gender=gender, country=country)

    surrogate_name = name_candidates.iloc[0]['first']
    if last_names:
        surrogate_name = surrogate_name + ' ' + name_candidates.iloc[1]['last']
    return surrogate_name
46
 
47
  def anonymize(
48
  anonymizer: AnonymizerEngine,
49
  text: str,
50
  analyze_results: list[RecognizerResult]
51
  ):
52
+ '''Anonymize identified input using Presidio Anonymizer.'''
53
 
54
  if not text:
55
  return
 
58
  text,
59
  analyze_results,
60
  operators={
61
+ 'STUDENT': OperatorConfig('custom',
62
+ {'lambda': generate_surrogate}),
63
+ 'EMAIL_ADDRESS': OperatorConfig('replace',
64
+ {'new_value': 'janedoe@aol.com'}),
65
+ 'PHONE_NUMBER': OperatorConfig('replace',
66
+ {'new_value': '888-888-8888'}),
67
+ 'URL': OperatorConfig('replace',
68
+ {'new_value': 'aol.com'}),
69
  }
70
  )
71
 
72
+ return res.text
73
+
74
+ if __name__ == '__main__':
75
+ print(generate_surrogate('Nora Wang'))
app.py CHANGED
@@ -1,7 +1,7 @@
1
 
2
- """Streamlit app for Student Name Detection models."""
3
 
4
- from spacy_analyzer import prepare_analyzer
5
  from anonymizer import anonymize
6
  from presidio_anonymizer import AnonymizerEngine
7
  import pandas as pd
@@ -11,18 +11,18 @@ import json
11
  import warnings
12
  import streamlit as st
13
  import os
14
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
15
  warnings.filterwarnings('ignore')
16
 
17
  # Helper methods
18
  @st.cache(allow_output_mutation=True)
19
  def analyzer_engine():
20
- """Return AnalyzerEngine and cache with Streamlit."""
21
 
22
  configuration = {
23
- "nlp_engine_name": "spacy",
24
- "models": [
25
- {"lang_code": "en", "model_name": "en_student_name_detector"}],
26
  }
27
 
28
  analyzer = prepare_analyzer(configuration)
@@ -31,7 +31,7 @@ def analyzer_engine():
31
 
32
  @st.cache(allow_output_mutation=True)
33
  def anonymizer_engine():
34
- """Return AnonymizerEngine."""
35
  return AnonymizerEngine()
36
 
37
  def annotate(text, st_analyze_results, st_entities):
@@ -54,57 +54,57 @@ def annotate(text, st_analyze_results, st_entities):
54
  return tokens
55
 
56
 
57
- st.set_page_config(page_title="Student Name Detector (English)", layout="wide")
58
 
59
  # Side bar
60
  st.sidebar.markdown(
61
- """Detect and anonymize PII in text using an [NLP model](https://huggingface.co/langdonholmes/en_student_name_detector) [trained](https://github.com/aialoe/deidentification-pipeline) on student-generated text collected by Coursera.
62
- """
63
  )
64
 
65
  st_entities = st.sidebar.multiselect(
66
- label="Which entities to look for?",
67
  options=analyzer_engine().get_supported_entities(),
68
  default=list(analyzer_engine().get_supported_entities()),
69
  )
70
 
71
  st_threshold = st.sidebar.slider(
72
- label="Acceptance threshold", min_value=0.0, max_value=1.0, value=0.35
73
  )
74
 
75
  st_return_decision_process = st.sidebar.checkbox(
76
- "Add analysis explanations in json")
77
 
78
  st.sidebar.info(
79
- "This is part of a deidentification project for student-generated text."
80
  )
81
 
82
  # Main panel
83
  analyzer_load_state = st.info(
84
- "Starting Presidio analyzer and loading Longformer-based model...")
85
  engine = analyzer_engine()
86
  analyzer_load_state.empty()
87
 
88
 
89
  st_text = st.text_area(
90
- label="Type in some text",
91
- value="Learning Reflection\n\nWritten by John Williams and Samantha Morales\n\nIn this course I learned many things. As Liedtke (2004) said, \"Students grow when they learn\" (Erickson et al. 1998).\n\nBy John H. Williams -- (714) 328-9989 -- [email protected]",
92
  height=200,
93
  )
94
 
95
- button = st.button("Detect PII")
96
 
97
  if 'first_load' not in st.session_state:
98
  st.session_state['first_load'] = True
99
 
100
  # After
101
- st.subheader("Analyzed")
102
- with st.spinner("Analyzing..."):
103
  if button or st.session_state.first_load:
104
  st_analyze_results = analyzer_engine().analyze(
105
  text=st_text,
106
  entities=st_entities,
107
- language="en",
108
  score_threshold=st_threshold,
109
  return_decision_process=st_return_decision_process,
110
  )
@@ -113,11 +113,11 @@ with st.spinner("Analyzing..."):
113
  annotated_text(*annotated_tokens)
114
 
115
  # vertical space
116
- st.text("")
117
 
118
- st.subheader("Anonymized")
119
 
120
- with st.spinner("Anonymizing..."):
121
  if button or st.session_state.first_load:
122
  st_anonymize_results = anonymize(anonymizer_engine(),
123
  st_text,
@@ -125,34 +125,34 @@ with st.spinner("Anonymizing..."):
125
  st_anonymize_results
126
 
127
  # table result
128
- st.subheader("Detailed Findings")
129
  if st_analyze_results:
130
  res_dicts = [r.to_dict() for r in st_analyze_results]
131
  for d in res_dicts:
132
  d['Value'] = st_text[d['start']:d['end']]
133
  df = pd.DataFrame.from_records(res_dicts)
134
- df = df[["entity_type", "Value", "score", "start", "end"]].rename(
135
  {
136
- "entity_type": "Entity type",
137
- "start": "Start",
138
- "end": "End",
139
- "score": "Confidence",
140
  },
141
  axis=1,
142
  )
143
 
144
  st.dataframe(df, width=1000)
145
  else:
146
- st.text("No findings")
147
 
148
  st.session_state['first_load'] = True
149
 
150
  # json result
151
  class ToDictListEncoder(JSONEncoder):
152
- """Encode dict to json."""
153
 
154
  def default(self, o):
155
- """Encode to JSON using to_dict."""
156
  if o:
157
  return o.to_dict()
158
  return []
 
1
 
2
+ '''Streamlit app for Student Name Detection models.'''
3
 
4
+ from analyzer import prepare_analyzer
5
  from anonymizer import anonymize
6
  from presidio_anonymizer import AnonymizerEngine
7
  import pandas as pd
 
11
  import warnings
12
  import streamlit as st
13
  import os
14
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
15
  warnings.filterwarnings('ignore')
16
 
17
  # Helper methods
18
  @st.cache(allow_output_mutation=True)
19
  def analyzer_engine():
20
+ '''Return AnalyzerEngine and cache with Streamlit.'''
21
 
22
  configuration = {
23
+ 'nlp_engine_name': 'spacy',
24
+ 'models': [
25
+ {'lang_code': 'en', 'model_name': 'en_student_name_detector'}],
26
  }
27
 
28
  analyzer = prepare_analyzer(configuration)
 
31
 
32
@st.cache(allow_output_mutation=True)
def anonymizer_engine():
    '''Build the Presidio AnonymizerEngine, cached across Streamlit reruns.'''
    return AnonymizerEngine()
36
 
37
  def annotate(text, st_analyze_results, st_entities):
 
54
  return tokens
55
 
56
 
57
+ st.set_page_config(page_title='Student Name Detector (English)', layout='wide')
58
 
59
  # Side bar
60
  st.sidebar.markdown(
61
+ '''Detect and anonymize PII in text using an [NLP model](https://huggingface.co/langdonholmes/en_student_name_detector) [trained](https://github.com/aialoe/deidentification-pipeline) on student-generated text collected by Coursera.
62
+ '''
63
  )
64
 
65
  st_entities = st.sidebar.multiselect(
66
+ label='Which entities to look for?',
67
  options=analyzer_engine().get_supported_entities(),
68
  default=list(analyzer_engine().get_supported_entities()),
69
  )
70
 
71
  st_threshold = st.sidebar.slider(
72
+ label='Acceptance threshold', min_value=0.0, max_value=1.0, value=0.35
73
  )
74
 
75
  st_return_decision_process = st.sidebar.checkbox(
76
+ 'Add analysis explanations in json')
77
 
78
  st.sidebar.info(
79
+ 'This is part of a deidentification project for student-generated text.'
80
  )
81
 
82
  # Main panel
83
  analyzer_load_state = st.info(
84
+ 'Starting Presidio analyzer and loading Longformer-based model...')
85
  engine = analyzer_engine()
86
  analyzer_load_state.empty()
87
 
88
 
89
  st_text = st.text_area(
90
+ label='Type in some text',
91
+ value='Learning Reflection\n\nWritten by John Williams and Samantha Morales\n\nIn this course I learned many things. As Liedtke (2004) said, \"Students grow when they learn\" (Erickson et al. 1998).\n\nBy John H. Williams -- (714) 328-9989 -- [email protected]',
92
  height=200,
93
  )
94
 
95
+ button = st.button('Detect PII')
96
 
97
  if 'first_load' not in st.session_state:
98
  st.session_state['first_load'] = True
99
 
100
  # After
101
+ st.subheader('Analyzed')
102
+ with st.spinner('Analyzing...'):
103
  if button or st.session_state.first_load:
104
  st_analyze_results = analyzer_engine().analyze(
105
  text=st_text,
106
  entities=st_entities,
107
+ language='en',
108
  score_threshold=st_threshold,
109
  return_decision_process=st_return_decision_process,
110
  )
 
113
  annotated_text(*annotated_tokens)
114
 
115
  # vertical space
116
+ st.text('')
117
 
118
+ st.subheader('Anonymized')
119
 
120
+ with st.spinner('Anonymizing...'):
121
  if button or st.session_state.first_load:
122
  st_anonymize_results = anonymize(anonymizer_engine(),
123
  st_text,
 
125
  st_anonymize_results
126
 
127
  # table result
128
+ st.subheader('Detailed Findings')
129
  if st_analyze_results:
130
  res_dicts = [r.to_dict() for r in st_analyze_results]
131
  for d in res_dicts:
132
  d['Value'] = st_text[d['start']:d['end']]
133
  df = pd.DataFrame.from_records(res_dicts)
134
+ df = df[['entity_type', 'Value', 'score', 'start', 'end']].rename(
135
  {
136
+ 'entity_type': 'Entity type',
137
+ 'start': 'Start',
138
+ 'end': 'End',
139
+ 'score': 'Confidence',
140
  },
141
  axis=1,
142
  )
143
 
144
  st.dataframe(df, width=1000)
145
  else:
146
+ st.text('No findings')
147
 
148
  st.session_state['first_load'] = True
149
 
150
  # json result
151
class ToDictListEncoder(JSONEncoder):
    '''JSON encoder for objects that expose a to_dict() method.'''

    def default(self, o):
        '''Serialize *o* via its to_dict(); falsy values encode as [].'''
        return o.to_dict() if o else []
data/{ascii_fb_names_small.parquet β†’ ascii_names.parquet} RENAMED
File without changes
match_replace.py DELETED
@@ -1,117 +0,0 @@
1
- import pandas as pd
2
-
3
- from names_database import NameDatabase
4
-
5
- names_db = NameDatabase
6
-
7
- def describe_name(first_names, last_names):
8
- gender = names_db.get_gender() if first_names else None
9
- country = names_db.get_country() if last_names else None
10
- return gender, country
11
-
12
- def split_name(all_names):
13
- '''Splits name into parts.
14
- If one token, assume it is a first name.
15
- If two tokens, first and last name.
16
- If three tokens, one first name and two last names.
17
- If four tokens, two first names and two last names.'''
18
- match all_names.split():
19
- case [first]:
20
- return first, None
21
- case [first, last]:
22
- return first, last
23
- case [first, last_1, last_2]:
24
- return first, ' '.join((last_1, last_2))
25
- case [first_1, first_2, last_1, last_2]:
26
- return ' '.join((first_1, first_2)), ' '.join((last_1, last_2))
27
- case _:
28
- return None, None
29
-
30
- def match_name(original_name):
31
- # FIXME: take too LONG time to run (large df used multi-times), how to improve
32
- # FIXME: here we only keep the first name for now
33
- # TODO: how to match both first and last? -- first name match gender, last name match country?
34
- # gender is not applied to last name
35
- # the name distinguished by first and last?
36
- # FIXME: since it is completely random, the same original name may be diff after replacing. How to know whether the two names is the same person?
37
- first_name = original_name.split()[0]
38
- global fb_df
39
- names = fb_df[fb_df['first']==first_name]
40
- if not names.empty:
41
- name_df = names.sample(n=1)
42
- # prevent for same name - deleting same name from df
43
- new_df = fb_df[fb_df['first'] != first_name]
44
- new_name = replace_name(name_df, new_df)
45
- return new_name
46
- else:
47
- return 'Jane Doe'
48
-
49
- def replace_name(name_df, new_df):
50
- """
51
- :param name_df: df that match the original first name -> data frame
52
- :param new_df: df that does not repeat with original name
53
- :return: whole name: that match country & gender -> str
54
- """
55
- gender = name_df['gender'].to_string(index=False)
56
- country = name_df['country'].to_string(index=False)
57
-
58
- # match country, then match gender
59
- country_df = new_df[new_df['country'] == country]
60
- country_g_df = country_df[country_df['gender'] == gender]
61
-
62
- first = country_g_df['first'].sample(n=1).to_string(index=False)
63
- last = country_g_df['last'].sample(n=1).to_string(index=False)
64
- return first+' '+last
65
-
66
-
67
-
68
- def match_name_2(original_name):
69
- """
70
- Work by match gender from first name, match country from the last name
71
- :param original_name:
72
- :return:
73
- """
74
- global fb_df
75
- fb_df = pd.read_parquet('ascii_fb_names_small.parquet')
76
- # FIXME: work when get a full name, may need branch to only first or last name....
77
- gender = name_match_gender(original_name.split()[0])
78
- print(original_name.split()[1])
79
- country = name_match_country(original_name.split()[-1])
80
- return replace_name_2(gender, country)
81
-
82
-
83
- def name_match_country(last_name):
84
- names = fb_df[fb_df['last'] == last_name]
85
- if not names.empty:
86
- country = names['country'].sample(n=1).to_string(index=False)
87
- return country
88
- else:
89
- return 'US'
90
-
91
- def name_match_gender(first_name):
92
- names = fb_df[fb_df['first'] == first_name]
93
- gender = names['gender'].sample(n=1).to_string(index=False)
94
- return gender
95
-
96
- def replace_name_2(gender, country):
97
- # TODO: prevent same name
98
- country_df = fb_df[fb_df['country'] == country]
99
- country_g_df = country_df[country_df['gender'] == gender]
100
-
101
- first = country_g_df['first'].sample(n=1).to_string(index=False)
102
- last = country_g_df['last'].sample(n=1).to_string(index=False)
103
- full_name = first +' ' + last
104
- return full_name
105
-
106
- def replace_text(str_list):
107
- surrogate_text = ''
108
- for i in str_list:
109
- if isinstance(i, tuple):
110
- i = match_entity(i[0], i[1])
111
- surrogate_text += i
112
- return surrogate_text
113
-
114
- if __name__ == "__main__":
115
- fb_df = pd.read_parquet('ascii_fb_names_small.parquet')
116
- # print(matching("PH", 'female', 'first', 'Momo', fb_df))
117
- print(match_entity('Nora Wang', 'STUDENT'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
names_database.py CHANGED
@@ -1,10 +1,15 @@
1
- from names_dataset import NameDataset, NameWrapper
2
  from typing import Optional
3
 
 
 
 
 
 
4
  class NameDatabase(NameDataset):
5
  def __init__(self) -> None:
6
  super().__init__()
7
- self.names = pd.read_parquet('ascii_fb_names_small.parquet')
8
 
9
  def get_random_name(
10
  self,
@@ -12,17 +17,26 @@ class NameDatabase(NameDataset):
12
  gender: Optional[str] = None
13
  ):
14
  '''country: ISO country code in 'alpha 2' format
15
- gender: "M" or "F"
 
16
  '''
17
  names_view = self.names
18
  if country:
19
  names_view = names_view[names_view['country'] == country]
20
  if gender:
21
  names_view = names_view[names_view['gender'] == gender]
22
- return names_view.sample(weights=names_view.count)
 
 
23
 
24
- def get_gender(first_names: str):
 
 
 
 
 
 
25
  return NameWrapper(self.search(first_names)).gender
26
-
27
- def get_country(last_names: str):
28
  return NameWrapper(self.search(last_names)).country
 
1
+ from pathlib import Path
2
  from typing import Optional
3
 
4
+ import pandas as pd
5
+ from names_dataset import NameDataset, NameWrapper
6
+
7
+ name_table = Path('data', 'ascii_names.parquet')
8
+
9
  class NameDatabase(NameDataset):
10
  def __init__(self) -> None:
11
  super().__init__()
12
+ self.names = pd.read_parquet(name_table)
13
 
14
  def get_random_name(
15
  self,
 
17
  gender: Optional[str] = None
18
  ):
19
  '''country: ISO country code in 'alpha 2' format
20
+ gender: 'M' or 'F'
21
+ returns two rows of the names dataframe
22
  '''
23
  names_view = self.names
24
  if country:
25
  names_view = names_view[names_view['country'] == country]
26
  if gender:
27
  names_view = names_view[names_view['gender'] == gender]
28
+ if names_view.size < 25:
29
+ return self.names.sample(n=2, weights=self.names['count'])
30
+ return names_view.sample(n=2, weights=names_view['count'])
31
 
32
+ def search(self, name: str):
33
+ key = name.strip().title()
34
+ fn = self.first_names.get(key) if self.first_names is not None else None
35
+ ln = self.last_names.get(key) if self.last_names is not None else None
36
+ return {'first_name': fn, 'last_name': ln}
37
+
38
+ def get_gender(self, first_names: str):
39
  return NameWrapper(self.search(first_names)).gender
40
+
41
+ def get_country(self, last_names: str):
42
  return NameWrapper(self.search(last_names)).country