Spaces:
Runtime error
Runtime error
anthony.galtier
committed on
Commit
·
06a851e
1
Parent(s):
ec36502
Added light code files
Browse files- bert/model.py +26 -0
- bert/performance.py +24 -0
- bert/preprocess_text.py +62 -0
- bert/tokenize.py +29 -0
- requirements.txt +7 -0
- text_to_price.py +52 -0
bert/model.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from transformers import CamembertModel
|
3 |
+
|
4 |
+
|
5 |
+
class CamembertRegressor(nn.Module):
    """Pretrained CamemBERT encoder topped with a single-output regression head.

    The head is a dropout layer followed by one linear layer mapping a
    768-dimensional vector to a single value.
    """

    def __init__(self, drop_rate=0.2, freeze_camembert=True):
        """Build the encoder and the regression head.

        Args:
            drop_rate: Dropout probability applied before the linear layer.
            freeze_camembert: When true, every encoder parameter is excluded
                from gradient computation, so only the head can be trained.
        """
        super().__init__()
        hidden_size, n_targets = 768, 1

        self.camembert = CamembertModel.from_pretrained('camembert-base')
        self.regressor = nn.Sequential(
            nn.Dropout(drop_rate),
            nn.Linear(hidden_size, n_targets),
        )

        if freeze_camembert:
            # Turns off requires_grad on every encoder parameter at once.
            self.camembert.requires_grad_(False)

    def forward(self, input_ids, attention_masks):
        """Run the encoder, then the regression head.

        Args:
            input_ids: Token ids, passed as the encoder's first argument.
            attention_masks: Attention masks, passed as its second argument.

        Returns:
            The output of the regression head.
        """
        encoded = self.camembert(input_ids, attention_masks)
        # Element 1 of the encoder output feeds the head
        # (presumably the pooled sentence vector - confirm in transformers docs).
        pooled = encoded[1]
        return self.regressor(pooled)
|
bert/performance.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from bert.tokenize import extract_inputs_masks, tokenize_encode_corpus
|
3 |
+
from torch.utils.data import TensorDataset, DataLoader
|
4 |
+
|
5 |
+
|
6 |
+
def predict(samples, tokenizer, scaler, model, device, max_len, batch_size,
            return_scaled=False):
    """Run the model over text samples and return one prediction per sample.

    Args:
        samples: Texts handed to ``tokenize_encode_corpus``.
        tokenizer: Tokenizer handed to ``tokenize_encode_corpus``.
        scaler: Object exposing ``inverse_transform``; used to map model
            outputs back to the label scale (presumably a fitted
            scikit-learn scaler - confirm against the training code).
        model: Callable as ``model(input_ids, attention_masks)``; it is put
            in eval mode before inference.
        device: Device the input tensors are moved to.
        max_len: Maximum sequence length used for padding/truncation.
        batch_size: Number of samples per inference batch.
        return_scaled: When true, return the raw model outputs without
            applying ``scaler.inverse_transform``.

    Returns:
        A flat list of predictions, in the order of ``samples``.
    """
    model.eval()

    encoding = tokenize_encode_corpus(tokenizer, samples, max_len)
    ids, masks = extract_inputs_masks(encoding)
    ids_tensor = torch.tensor(ids).to(device)
    masks_tensor = torch.tensor(masks).to(device)
    loader = DataLoader(TensorDataset(ids_tensor, masks_tensor), batch_size)

    predictions = []
    # Inference only: no gradients are needed for any batch.
    with torch.no_grad():
        for batch_inputs, batch_masks in loader:
            batch_inputs = batch_inputs.to(device)
            batch_masks = batch_masks.to(device)
            batch_out = model(batch_inputs, batch_masks)
            predictions.extend(batch_out.view(-1).tolist())

    if return_scaled:
        return predictions

    unscaled = scaler.inverse_transform([predictions])
    return unscaled.reshape(-1).tolist()
|
bert/preprocess_text.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import re
|
3 |
+
|
4 |
+
|
5 |
+
def treat_euro(text):
    """Normalise every euro mention ('euro', 'euros', '€') to ' euros'.

    The word forms are matched only as a whole word ending (``\\b``), so:
      * the character following 'euro' is kept (no more glued words such as
        'eurospar'),
      * 'euro' at the very end of the text is handled,
      * longer words that merely start with 'euro' (e.g. 'europe') are left
        untouched.

    Args:
        text: Input string (expected to be lower-cased already).

    Returns:
        The text with each euro mention replaced by ' euros'.
    """
    return re.sub(r'euros?\b|€', ' euros', text)
|
8 |
+
|
9 |
+
|
10 |
+
def treat_m2(text):
    """Rewrite every 'm2' or 'm²' occurrence as ' m²' (with a leading space).

    Args:
        text: Input string.

    Returns:
        The text with square-metre mentions normalised.
    """
    square_metre = re.compile(r'(m2)|(m²)')
    return square_metre.sub(' m²', text)
|
13 |
+
|
14 |
+
|
15 |
+
def filter_phone_numbers(text):
    """Delete phone numbers from the text.

    Two shapes are recognised: numbers starting with '+33', '0033' or '0'
    followed by nine digits in common groupings (optionally separated by
    whitespace, dots or dashes), and space-separated '2 2 3 3' digit groups.

    Args:
        text: Input string.

    Returns:
        The text with every matched phone number removed.
    """
    phone_number = re.compile(
        r'(?:(?:\+|00)33[\s.-]{0,3}(?:\(0\)[\s.-]{0,3})?|0)[1-9](?:(?:[\s.-]?\d{2}){4}|\d{2}(?:[\s.-]?\d{3}){2})|(\d{2}[ ]\d{2}[ ]\d{3}[ ]\d{3})'
    )
    return phone_number.sub('', text)
|
19 |
+
|
20 |
+
|
21 |
+
def filter_ibans(text):
    """Delete lower-case 'fr…' IBAN-like sequences from the text.

    Three layouts are matched: 'frNN NNNN NNNN NNNN NNNN NN', 'fr' followed
    by twenty contiguous digits, and 'fr NN NNN NNN NNN NNNNN'.

    Args:
        text: Input string (expected to be lower-cased already).

    Returns:
        The text with every matched sequence removed.
    """
    iban = re.compile(
        r'fr\d{2}[ ]\d{4}[ ]\d{4}[ ]\d{4}[ ]\d{4}[ ]\d{2}|fr\d{20}|fr[ ]\d{2}[ ]\d{3}[ ]\d{3}[ ]\d{3}[ ]\d{5}'
    )
    return iban.sub('', text)
|
25 |
+
|
26 |
+
|
27 |
+
def remove_space_between_numbers(text):
    """Remove whitespace that sits directly between two digits.

    Lookarounds are used so the digits themselves are not consumed by the
    match; this lets runs of single digits ('1 2 3') collapse completely
    instead of leaving every other gap in place.

    Args:
        text: Input string.

    Returns:
        The text with digit-to-digit whitespace removed
        (e.g. '250 000' becomes '250000').
    """
    return re.sub(r'(?<=\d)\s+(?=\d)', '', text)
|
30 |
+
|
31 |
+
|
32 |
+
def filter_emails(text):
    """Delete e-mail addresses from the text.

    Compared with a naive pattern, two details matter here:
      * the dot before the top-level domain is escaped, so the match stops at
        the end of the address instead of swallowing the next word;
      * the "no consecutive dots" check only inspects the local part (up to
        the '@'), so an ellipsis elsewhere in the text does not prevent the
        address from being removed.

    Args:
        text: Input string.

    Returns:
        The text with every matched e-mail address removed.
    """
    pattern = (
        r'(?:(?![^\s@]*[.]{2})[a-zA-Z0-9](?:[a-zA-Z0-9.+!%-]{1,64}|)'
        r'|\"[a-zA-Z0-9.+!% -]{1,64}\")'
        r'@[a-zA-Z0-9][a-zA-Z0-9.-]+(?:\.[a-z]{2,}|\.[0-9]{1,})'
    )
    return re.sub(pattern, '', text)
|
36 |
+
|
37 |
+
|
38 |
+
def filter_ref(text):
    """Delete listing references such as '(ref.1234)' or 'réf 42'.

    A reference is 'ref' or 'réf', then a dot or a space, then digits,
    optionally wrapped in any number of parentheses.

    Args:
        text: Input string (expected to be lower-cased already).

    Returns:
        The text with every matched reference removed.
    """
    reference = re.compile(r'(\(*)(ref|réf)(\.|[ ])\d+(\)*)')
    return reference.sub('', text)
|
42 |
+
|
43 |
+
|
44 |
+
def filter_websites(text):
    """Delete website-like tokens from the text.

    A token is an optional 'http://' or 'https://' prefix followed by one or
    more dot-terminated labels and a final alphabetic label. Anything after
    the host name (such as a path) is not part of the match.

    Args:
        text: Input string (expected to be lower-cased already).

    Returns:
        The text with every matched token removed.
    """
    website = re.compile(
        r'(http\:\/\/|https\:\/\/)?([a-z0-9][a-z0-9\-]*\.)+[a-z][a-z\-]*'
    )
    return website.sub('', text)
|
48 |
+
|
49 |
+
|
50 |
+
def preprocess_text_for_camembert(text):
    """Clean a description before it is tokenized.

    The text is lower-cased, non-breaking spaces become regular spaces, and
    the cleaning steps below are then applied in a fixed order.

    Args:
        text: Raw description.

    Returns:
        The cleaned description.
    """
    cleaned = text.lower().replace(u'\xa0', u' ')

    # Order matters: each step works on the output of the previous one.
    steps = (
        treat_m2,
        treat_euro,
        filter_phone_numbers,
        filter_emails,
        filter_ibans,
        filter_ref,
        filter_websites,
        remove_space_between_numbers,
    )
    for step in steps:
        cleaned = step(cleaned)
    return cleaned
|
62 |
+
|
bert/tokenize.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import CamembertTokenizer
|
2 |
+
|
3 |
+
|
4 |
+
def get_tokenizer(model_name='camembert-base'):
    """Load the pretrained CamemBERT tokenizer for ``model_name``.

    Args:
        model_name: Name passed to ``CamembertTokenizer.from_pretrained``.

    Returns:
        The loaded tokenizer.
    """
    return CamembertTokenizer.from_pretrained(model_name)
|
7 |
+
|
8 |
+
|
9 |
+
def tokenize_encode_corpus(tokenizer, descriptions, max_len):
    """Encode descriptions with a fixed length and attention masks.

    Args:
        tokenizer: Callable tokenizer accepting the keyword arguments below.
        descriptions: Texts to encode, passed as ``text``.
        max_len: Length every sequence is padded or truncated to.

    Returns:
        Whatever the tokenizer returns for these settings.
    """
    encoding_options = dict(
        add_special_tokens=True,
        padding='max_length',
        truncation='longest_first',
        max_length=max_len,
        return_attention_mask=True,
    )
    return tokenizer(text=descriptions, **encoding_options)
|
18 |
+
|
19 |
+
|
20 |
+
def extract_inputs_masks(encoded_corpus):
    """Pull the input ids and attention masks out of an encoded corpus.

    Args:
        encoded_corpus: Mapping expected to hold the 'input_ids' and
            'attention_mask' keys.

    Returns:
        ``(input_ids, attention_mask)`` when both keys are present.
        ``None`` when a key is missing; the available keys are printed to
        help diagnose the problem.
    """
    try:
        input_ids = encoded_corpus['input_ids']
        attention_mask = encoded_corpus['attention_mask']
    except KeyError:
        # Only a missing key is handled here; unrelated errors (and
        # KeyboardInterrupt) are no longer swallowed by a bare except.
        print('Available keys are = ', encoded_corpus.keys())
        return None
    return input_ids, attention_mask
|
29 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy==1.21.6
|
2 |
+
scikit-learn==1.0.2
|
3 |
+
torch==1.12.1
|
4 |
+
transformers==4.21.3
|
5 |
+
sentencepiece==0.1.97
|
6 |
+
streamlit==1.12.2
|
7 |
+
Babel==2.10.3
|
text_to_price.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from babel.numbers import format_currency
|
5 |
+
|
6 |
+
from bert.tokenize import get_tokenizer
|
7 |
+
from bert.model import CamembertRegressor
|
8 |
+
from bert.performance import predict
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
MODEL_STATE_DICT_PATH = './bert/trained_model/model_epoch_5.pt'


# ENVIRONMENT SET UP
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# MODEL LOADING
# The checkpoint is a dict read below for three entries: the model weights,
# the maximum input length and the label scaler.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
saved_model_dict = torch.load(MODEL_STATE_DICT_PATH, map_location=device)
model = CamembertRegressor()
model.load_state_dict(saved_model_dict['model_state_dict'])
# predict() moves the inputs to `device`; the model must live there too,
# otherwise inference fails on hosts where CUDA is available.
model.to(device)

tokenizer = get_tokenizer()
max_len = saved_model_dict['max_input_len']
scaler = saved_model_dict['labels_scaler']


# WEB APP
st.title("Text 2 Price - Real Estate")
st.markdown("")


example_description = "Superbe maison de 500m2 à Pétaouchnok..."
description = st.text_area("Décris ton bien immobilier : ", example_description)


# Only predict once the user has typed something other than the placeholder.
if description and description != example_description:
    predicted_price = predict([description], tokenizer, scaler, model, device,
                              max_len, 32, return_scaled=False)[0]
    predicted_price_formatted = format_currency(predicted_price, 'EUR',
                                                locale='fr_FR')
    st.markdown('')
    st.markdown('')
    st.markdown('On estime que ton bien immobilier serait annoncé à :')
    # The injected value is a formatted number, not user-supplied text.
    st.markdown(
        "<h1 style='text-align: center;'>" + predicted_price_formatted + "</h1>",
        unsafe_allow_html=True,
    )
|