File size: 6,441 Bytes
9c3709d
 
 
 
 
 
4c66227
d803be1
9c3709d
8c28786
9c3709d
9847233
d803be1
7e9684b
7402de3
9c3709d
 
 
 
 
 
8c28786
9c3709d
 
d803be1
6f80de5
4c66227
6f80de5
d803be1
7e9684b
7402de3
9c3709d
 
 
 
 
 
8c28786
9c3709d
 
 
8d1e83e
d803be1
9c3709d
 
 
 
d594a38
 
8c28786
d803be1
9c3709d
 
 
 
542890e
 
 
6f80de5
542890e
 
6f80de5
542890e
 
 
 
 
 
 
 
6f80de5
 
 
 
 
 
542890e
9c3709d
8d1e83e
9c3709d
d594a38
9c3709d
d803be1
 
7402de3
8c28786
4c66227
d803be1
9c3709d
8d1e83e
d803be1
9847233
7e9684b
d803be1
9c3709d
8d1e83e
d803be1
6f80de5
ca913e4
 
6f80de5
 
8d1e83e
7402de3
e5a770a
 
 
 
7402de3
9c3709d
d803be1
9847233
 
8d1e83e
 
 
 
 
d594a38
9c3709d
 
8d1e83e
 
 
d803be1
9c3709d
 
9847233
df527c8
9c3709d
9847233
d803be1
9c3709d
d803be1
7e9684b
8c28786
7e9684b
8c28786
9c3709d
8c28786
 
 
d803be1
8c28786
9847233
8c28786
 
 
 
 
 
 
d803be1
8c28786
9847233
8c28786
 
 
 
 
d803be1
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""search_agent.py

Usage:
    search_agent.py 
        [--domain=domain]
        [--provider=provider]
        [--model=model]
        [--embedding_model=model]
        [--temperature=temp]
        [--copywrite]
        [--max_pages=num]
        [--max_extracts=num]
        [--use_selenium]
        [--output=text]
        [--verbose]
        SEARCH_QUERY
    search_agent.py --version

Options:
    -h --help                           Show this screen.
    --version                           Show version.
    -c --copywrite                      First produce a draft, review it and rewrite for a final text
    -d domain --domain=domain           Limit search to a specific domain
    -t temp --temperature=temp          Set the temperature of the LLM [default: 0.0]
    -m model --model=model              Use a specific model [default: openai/gpt-4o-mini]
    -e model --embedding_model=model    Use a specific embedding model [default: same provider as model]
    -n num --max_pages=num              Max number of pages to retrieve [default: 10]
    -x num --max_extracts=num           Max number of page extracts to consider [default: 7]
    -s --use_selenium                   Use selenium to fetch content from the web [default: False]
    -o text --output=text               Output format (choices: text, markdown) [default: markdown]
    -v --verbose                        Print verbose output [default: False]

"""

import os

from docopt import docopt
#from schema import Schema, Use, SchemaError
import dotenv

from langchain.callbacks import LangChainTracer

from langsmith import Client, traceable

from rich.console import Console
from rich.markdown import Markdown

import web_rag as wr
import web_crawler as wc
import copywriter as cw
import models as md

# Shared rich console used for every status spinner, log line and printed
# response in this script.
# NOTE(review): the Console is created before .env is loaded below, so any
# console-related variables defined only in .env will not affect it — confirm
# this ordering is intended.
console = Console()
# Load variables from a local .env file into os.environ (e.g. the
# LANGCHAIN_API_KEY checked when building `callbacks`).
dotenv.load_dotenv()

def get_selenium_driver():
    """Build a headless Chrome WebDriver for fetching page content.

    Selenium is imported inside the function, so this module can be loaded
    without selenium installed; the import only happens when a driver is
    actually requested.

    Returns:
        A ``webdriver.Chrome`` instance, or ``None`` when Chrome could not be
        started (the WebDriverException message is printed).
    """
    from selenium import webdriver
    from selenium.webdriver.chrome.options import Options
    from selenium.common.exceptions import WebDriverException

    # Applied in this exact order: headless, no extensions/GPU/sandbox,
    # no /dev/shm usage, fixed debugging port, images off, fixed viewport.
    chrome_flags = (
        "--headless",
        "--disable-extensions",
        "--disable-gpu",
        "--no-sandbox",
        "--disable-dev-shm-usage",
        "--remote-debugging-port=9222",
        '--blink-settings=imagesEnabled=false',
        "--window-size=1920,1080",
    )
    opts = Options()
    for flag in chrome_flags:
        opts.add_argument(flag)

    try:
        chrome = webdriver.Chrome(options=opts)
    except WebDriverException as err:
        print(f"Error creating Selenium WebDriver: {err}")
        return None
    return chrome

# LangSmith tracing is wired up only when an API key is present in the
# environment; otherwise no callbacks are registered.
# NOTE(review): `callbacks` is not referenced anywhere else in this file —
# confirm whether it is meant to be passed to the chat model or chains.
callbacks = (
    [LangChainTracer(client=Client())]
    if os.getenv("LANGCHAIN_API_KEY")
    else []
)
def _provider_of(model):
    """Return the provider prefix of a model id.

    Accepts both ``provider:model`` and ``provider/model`` forms (the
    ``--model`` default in the usage text is ``openai/gpt-4o-mini``). The text
    before the first ``:`` or ``/`` is returned, or the whole string when
    neither separator occurs.
    """
    return model.split(":", 1)[0].split("/", 1)[0]


def _display_name(obj):
    """Best-effort human-readable name for a chat or embedding model object."""
    return (
        getattr(obj, "model_name", None)
        or getattr(obj, "model", None)
        or getattr(obj, "model_id", None)
        or str(obj)
    )


def _print_section(title, text, output):
    """Print *text* framed by rules, as plain text or rendered Markdown.

    ``output == "text"`` prints the raw string; any other value renders it
    as Markdown (matching the ``--output`` default of ``markdown``).
    """
    console.rule(f"[bold green]{title}")
    console.print(text if output == "text" else Markdown(text))
    console.rule("[bold green]")


@traceable(run_type="tool", name="search_agent")
def main(arguments):
    """Answer SEARCH_QUERY from web sources, optionally with a review pass.

    Pipeline:
    1. Ask the chat model for an optimized search query.
    2. Search for sources (optionally limited to ``--domain``).
    3. Fetch their content (optionally through Selenium).
    4. Embed the content into a vector store.
    5. Write a draft answer from the top ``--max_extracts`` extracts.

    With ``--copywrite``, the draft is then reviewed and rewritten into a
    final text.

    Args:
        arguments: The dict produced by ``docopt`` from this module's usage
            text.
    """
    verbose = arguments["--verbose"]
    copywrite_mode = arguments["--copywrite"]
    model = arguments["--model"]
    embedding_model = arguments["--embedding_model"]
    temperature = float(arguments["--temperature"])
    domain = arguments["--domain"]
    max_pages = int(arguments["--max_pages"])
    max_extract = int(arguments["--max_extracts"])
    output = arguments["--output"]
    use_selenium = arguments["--use_selenium"]
    query = arguments["SEARCH_QUERY"]

    chat = md.get_model(model, temperature)
    if embedding_model.lower() == "same provider as model":
        # Only the provider prefix is wanted here. Splitting on ':' alone left
        # '/'-style ids such as the default "openai/gpt-4o-mini" intact, so a
        # chat model id was passed where a provider name was expected.
        embedding_model = md.get_embedding_model(_provider_of(model))
    else:
        embedding_model = md.get_embedding_model(embedding_model)

    if verbose:
        console.log(f"Using model: {_display_name(chat)}")
        console.log(f"Using embedding model: {_display_name(embedding_model)}")

    with console.status(f"[bold green]Optimizing query for search: {query}"):
        optimized_query = wr.optimize_search_query(chat, query)
        # Fall back to the user's query when the model returns nothing usable.
        if not optimized_query or len(optimized_query) < 3:
            optimized_query = query
    console.log(f"Optimized search query: [bold blue]{optimized_query}")

    with console.status(
        f"[bold green]Searching sources using the optimized query: {optimized_query}"
    ):
        sources = wc.get_sources(optimized_query, max_pages=max_pages, domain=domain)
    console.log(f"Found {len(sources)} sources {'on ' + domain if domain else ''}")

    with console.status(
        f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical"
    ):
        # The driver factory is passed uncalled; wc decides whether to use it.
        contents = wc.get_links_contents(sources, get_selenium_driver, use_selenium=use_selenium)
    console.log(f"Managed to extract content from {len(contents)} sources")

    with console.status(f"[bold green]Embedding {len(contents)} sources for content", spinner="growVertical"):
        vector_store = wc.vectorize(contents, embedding_model)

    with console.status("[bold green]Writing content", spinner='dots8Bit'):
        draft = wr.query_rag(chat, query, optimized_query, vector_store, top_k=max_extract)

    _print_section("Response", draft, output)

    if copywrite_mode:
        with console.status("[bold green]Getting comments from the reviewer", spinner="dots8Bit"):
            comments = cw.generate_comments(chat, query, draft)
        _print_section("Response from reviewer", comments, output)

        with console.status("[bold green]Writing the final text", spinner="dots8Bit"):
            final_text = cw.generate_final_text(chat, query, draft, comments)
        _print_section("Final text", final_text, output)

if __name__ == "__main__":
    # Parse argv against the usage text in the module docstring, then run.
    arguments = docopt(__doc__, version="Search Agent 0.1")
    main(arguments)