ddiddu committed on
Commit
1209aa3
·
1 Parent(s): 9107f8d

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +59 -0
inference.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
2
+ import torch
3
+
4
# Shared GPT-2 language model and its tokenizer, loaded once at import time.
model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Fixed answers for specific characters, keyed by upper-cased character name;
# these characters are never sent to the model.
character_gender_mapping = dict(
    NARRATOR="neutral",
    FATHER="male",
    HARPER="female",
)
14
+
15
def _first_gender_word(text):
    """Return the first standalone 'male' or 'female' word in *text*.

    Returns 'unknown' when neither word occurs.
    """
    # Compare whole words: a substring test for 'male' also matches 'female',
    # which would make the 'female' outcome unreachable.
    letters_only = "".join(ch if ch.isalpha() else " " for ch in text.lower())
    for word in letters_only.split():
        if word in ("male", "female"):
            return word
    return "unknown"


def predict_gender_aggregated(character, lines, *, max_new_tokens=20):
    """Predict a character's gender from their aggregated dialogue.

    Characters present in ``character_gender_mapping`` (looked up by
    upper-cased name) return the mapped value without running the model.
    For every other character the dialogue is joined into one prompt ending
    in "Gender:", GPT-2 samples a continuation, and the first standalone
    'male' / 'female' word of that continuation is the prediction.

    Sampling is enabled (``do_sample=True``), so repeated calls with the
    same input may return different answers.

    Args:
        character: Character name (string).
        lines: Dialogue lines spoken by the character. A single string is
            treated as one line rather than being joined character by
            character.
        max_new_tokens: Number of tokens the model may generate after the
            prompt.

    Returns:
        The mapped value for known characters; otherwise 'male', 'female',
        or 'unknown' when the continuation mentions neither.
    """
    # Check if the character is in the mapping
    lookup_key = character.upper()
    if lookup_key in character_gender_mapping:
        return character_gender_mapping[lookup_key]

    if isinstance(lines, str):
        lines = [lines]

    aggregated_text = " ".join(lines)
    input_text = f"Character: {character}. Dialogue: {aggregated_text}. Gender:"
    input_ids = tokenizer.encode(input_text, return_tensors='pt')

    # Keep the prompt plus the continuation within the model's position
    # limit. When the prompt is too long, keep its tail so the most recent
    # dialogue and the trailing "Gender:" cue survive.
    prompt_budget = max(model.config.n_positions - max_new_tokens, 1)
    if input_ids.shape[1] > prompt_budget:
        input_ids = input_ids[:, -prompt_budget:]

    # Every prompt token is real (no padding), so attend to all of them.
    attention_mask = torch.ones_like(input_ids)

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        # max_length would count the prompt tokens too, leaving little or no
        # room to generate once the aggregated dialogue gets long.
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        # GPT-2 defines no pad token; reuse EOS explicitly.
        pad_token_id=tokenizer.eos_token_id,
    )

    # generate() returns prompt + continuation; inspect only the continuation
    # so words in the character name or dialogue cannot decide the result.
    prompt_length = input_ids.shape[1]
    continuation = tokenizer.decode(output[0, prompt_length:], skip_special_tokens=True)

    return _first_gender_word(continuation)
40
+
41
def predict(input_data):
    """Inference entry point.

    Reads the "character" and "lines" entries from *input_data* via ``.get``.
    If either is missing or falsy, returns a dict with a single "error" key.
    Otherwise returns the character together with the value produced by
    ``predict_gender_aggregated`` under "predicted_gender".
    """
    character, lines = (input_data.get(field) for field in ("character", "lines"))

    # Guard clause: both fields must be present and non-empty.
    if not (character and lines):
        return {"error": "Missing character or lines in the input"}

    return {
        "character": character,
        "predicted_gender": predict_gender_aggregated(character, lines),
    }
52
+
53
# Local smoke test: shows the expected request shape and prints the response.
if __name__ == "__main__":
    sample_request = dict(
        character="FATHER",
        lines=["I am very proud of you, son."],
    )
    print(predict(sample_request))