import os

import gradio as gr
import pandas as pd
from huggingface_hub import InferenceClient, login, snapshot_download
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings


"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co./docs/huggingface_hub/v0.22.2/en/guides/inference
"""
login(token=os.getenv('TOKEN'))

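# NOTE: chat_completion targets conversational models; mt5-small is a seq2seq
# model, so one of the Instruct models below is likely the safer choice.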
#model = "meta-llama/Llama-3.2-1B-Instruct"
#model = "mistralai/Mistral-7B-Instruct-v0.3"
model = "google/mt5-small"

client = InferenceClient(model)

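# Fetch the prebuilt FAISS index (published as a dataset repo) into the
# current working directory.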
folder = snapshot_download(repo_id="umaiku/faiss_index", repo_type="dataset", local_dir=os.getcwd())

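# The embedding model must match the one used when the index was built.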
embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-small")

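# FAISS indexes are pickled on disk, hence allow_dangerous_deserialization;
# only load indexes from a trusted source.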
vector_db = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)

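# Case database keyed by case_url, used to expand retrieved chunks into
# full case texts.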
df = pd.read_csv("faiss_index/bger_cedh_db 1954-2024.csv")

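# Streaming chat handler: retrieve matching cases, build a RAG prompt,
# and stream the model's answer back to the UI.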
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    score,
):
    messages = [{"role": "system", "content": system_message}]

    print(system_message)
    
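    # Retrieve only documents whose similarity score clears the user-chosen threshold.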
    retriever = vector_db.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": score})
    documents = retriever.invoke(message)

    spacer = " \n"

    context = ""

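    # Keep only the three highest-ranked cases to bound the prompt size.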
    for doc in documents[:3]:
        case_text = df[df["case_url"] == doc.metadata["case_url"]].case_text.values[0]
        
        context += "Case number: " + doc.metadata["case_nb"] + spacer
        context += "Case date: " + doc.metadata["case_date"] + spacer
        context += "Case url: " + doc.metadata["case_url"] + spacer
        context += "Case text: " + case_text + spacer
    
    # Wrap the user question and the retrieved cases into a RAG prompt.
    message = f"""
A user is asking you the following question: {message}
Please answer the user in the same language they used in their question.
Use the following context collected from various Swiss federal jurisprudence cases:
{context}
Please mention your sources in your answer, including the URLs and dates.
Always answer in the language used in the question, which was: {message}
    """
    
    print(message)

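    # Conversation history is currently not forwarded to the model;
    # uncomment the block below to restore multi-turn context.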
#    for val in history:
#        if val[0]:
#            messages.append({"role": "user", "content": val[0]})
#        if val[1]:
#            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": message})

    response = ""

    # Stream the completion token by token; `chunk` avoids shadowing `message`.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content

        # The final streamed chunk may carry a None delta, so guard before appending.
        if token:
            response += token
        yield response


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are an assistant in Swiss Jurisprudence cases.", label="System message"),
        gr.Slider(minimum=1, maximum=24000, value=5000, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(minimum=0, maximum=1, value=0.7, step=0.1, label="Score Threshold"),
    ],
    description="# 📜 ALexI: Artificial Legal Intelligence for Swiss Jurisprudence",
)


if __name__ == "__main__":
    demo.launch(debug=True)