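"""Chainlit chat app backed by a LlamaIndex query engine.

Loads (or, on first run, builds) a vector index over ``hitchhikers.pdf``
and answers chat messages with a fine-tuned gpt-3.5-turbo model, streaming
tokens back to the UI as they are generated.
"""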
import os
import openai

from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine
from llama_index.callbacks.base import CallbackManager
from llama_index import (
    LLMPredictor,
    ServiceContext,
    StorageContext,
    load_index_from_storage,
)
from llama_index.llms import OpenAI
import chainlit as cl


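# The key is read from the environment; export OPENAI_API_KEY before running.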
openai.api_key = os.environ.get("OPENAI_API_KEY")

try:
    # Rebuild the storage context and load the index persisted on disk.
    storage_context = StorageContext.from_defaults(persist_dir="./storage")
    index = load_index_from_storage(storage_context)
except Exception:
    # No persisted index yet: build one from the source PDF and persist it
    # so subsequent runs can skip re-indexing.
    from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader

    documents = SimpleDirectoryReader(input_files=["hitchhikers.pdf"]).load_data()
    index = GPTVectorStoreIndex.from_documents(documents)
    index.storage_context.persist()


@cl.on_chat_start
async def factory():
    # llama_index's OpenAI wrapper takes `model`, not `model_name`; unknown
    # keywords are silently dropped, so the original `model_name=` call fell
    # back to the default model instead of the fine-tuned one.
    llm_predictor = LLMPredictor(
        llm=OpenAI(
            temperature=0,
            model="ft:gpt-3.5-turbo-0613:personal::7sckuHOj",
            streaming=True,
        ),
    )
    service_context = ServiceContext.from_defaults(
        llm_predictor=llm_predictor,
        chunk_size=512,
        # Cap the prompt size here; context_window belongs to the service
        # context (prompt helper), not to the OpenAI constructor.
        context_window=2048,
        callback_manager=CallbackManager([cl.LlamaIndexCallbackHandler()]),
    )

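    # Build a streaming query engine over the shared index and stash it in
    # the user session so each chat has its own engine instance.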
    query_engine = index.as_query_engine(
        service_context=service_context,
        streaming=True,
    )

    cl.user_session.set("query_engine", query_engine)


@cl.on_message
async def main(message):
    query_engine = cl.user_session.get("query_engine")  # type: RetrieverQueryEngine
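    # query() is synchronous; cl.make_async runs it in a worker thread so
    # the Chainlit event loop is not blocked while the model responds.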
    response = await cl.make_async(query_engine.query)(message)

    response_message = cl.Message(content="")

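    # Relay tokens to the UI as they arrive from the streaming response.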
    for token in response.response_gen:
        await response_message.stream_token(token=token)

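    # response_txt is only populated if the engine finalized the stream
    # itself; otherwise keep the tokens accumulated above.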
    if response.response_txt:
        response_message.content = response.response_txt

    await response_message.send()