geekyrakshit committed
Commit 8d64162 · 1 Parent(s): 8bd2693

add: LLMClient

medrag_multi_modal/assistant/__init__.py ADDED
File without changes
medrag_multi_modal/assistant/llm_client.py ADDED
@@ -0,0 +1,96 @@
+ import os
+ from enum import Enum
+ from typing import Any, Optional, Union
+
+ import instructor
+ import weave
+ from PIL import Image
+
+ from ..utils import base64_encode_image
+
+
+ class ClientType(str, Enum):
+     GEMINI = "gemini"
+     MISTRAL = "mistral"
+
+
+ class LLMClient(weave.Model):
+     model_name: str
+     client_type: ClientType
+
+     def __init__(self, model_name: str, client_type: ClientType):
+         super().__init__(model_name=model_name, client_type=client_type)
+
+     @weave.op()
+     def execute_gemini_sdk(
+         self,
+         user_prompt: Union[str, list[str]],
+         system_prompt: Optional[Union[str, list[str]]] = None,
+         schema: Optional[Any] = None,
+     ) -> Union[str, Any]:
+         import google.generativeai as genai
+
+         genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
+         model = genai.GenerativeModel(self.model_name, system_instruction=system_prompt)
+         generation_config = (
+             None
+             if schema is None
+             else genai.GenerationConfig(
+                 response_mime_type="application/json", response_schema=list[schema]
+             )
+         )
+         response = model.generate_content(
+             user_prompt, generation_config=generation_config
+         )
+         return response.text if schema is None else response
+
+     @weave.op()
+     def execute_mistral_sdk(
+         self,
+         user_prompt: Union[str, list[str]],
+         system_prompt: Optional[Union[str, list[str]]] = None,
+         schema: Optional[Any] = None,
+     ) -> Union[str, Any]:
+         from mistralai import Mistral
+
+         system_prompt = (
+             [system_prompt] if isinstance(system_prompt, str) else system_prompt
+         )
+         user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
+         messages = [{"type": "text", "text": prompt} for prompt in system_prompt]
+         for prompt in user_prompt:
+             if isinstance(prompt, Image.Image):
+                 messages.append(
+                     {
+                         "type": "image_url",
+                         "image_url": base64_encode_image(prompt, "image/png"),
+                     }
+                 )
+             else:
+                 messages.append({"type": "text", "text": prompt})
+
+         client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"))
+         client = instructor.from_mistral(client)
+
+         response = (
+             client.chat.complete(model=self.model_name, messages=messages)
+             if schema is None
+             else client.messages.create(
+                 response_model=schema, messages=messages, temperature=0
+             )
+         )
+         return response.choices[0].message.content
+
+     @weave.op()
+     def predict(
+         self,
+         user_prompt: Union[str, list[str]],
+         system_prompt: Optional[Union[str, list[str]]] = None,
+         schema: Optional[Any] = None,
+     ) -> Union[str, Any]:
+         if self.client_type == ClientType.GEMINI:
+             return self.execute_gemini_sdk(user_prompt, system_prompt, schema)
+         elif self.client_type == ClientType.MISTRAL:
+             return self.execute_mistral_sdk(user_prompt, system_prompt, schema)
+         else:
+             raise ValueError(f"Invalid client type: {self.client_type}")
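
For orientation, here is a minimal usage sketch of the new client (not part of this commit; the Weave project name and the Gemini model name below are placeholders, and GOOGLE_API_KEY must be set in the environment):

import weave

from medrag_multi_modal.assistant.llm_client import ClientType, LLMClient

weave.init(project_name="medrag-multi-modal")  # hypothetical project name

# predict() dispatches to execute_gemini_sdk() because client_type is GEMINI.
client = LLMClient(model_name="gemini-1.5-flash", client_type=ClientType.GEMINI)
answer = client.predict(
    user_prompt="Summarize the indications for metformin in two sentences.",
    system_prompt="You are a concise medical assistant.",
)
print(answer)  # plain text, since no schema was passed
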
medrag_multi_modal/utils.py CHANGED
@@ -1,4 +1,8 @@
+ import base64
+ import io
+
  import torch
+ from PIL import Image

  import wandb

@@ -29,3 +33,11 @@ def get_torch_backend():
          return "mps"
          return "cpu"
      return "cpu"
+
+
+ def base64_encode_image(image: Image.Image, mimetype: str) -> str:
+     byte_arr = io.BytesIO()
+     image.save(byte_arr, format="PNG")
+     encoded_string = base64.b64encode(byte_arr.getvalue()).decode("utf-8")
+     encoded_string = f"data:{mimetype};base64,{encoded_string}"
+     return str(encoded_string)
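
As an illustration, the new helper produces the data-URI string that execute_mistral_sdk() embeds as an image_url part. A quick sketch, not part of the commit:

from PIL import Image

from medrag_multi_modal.utils import base64_encode_image

# Any PIL image works; a blank 32x32 RGB canvas stands in for a page scan here.
image = Image.new("RGB", (32, 32), color="white")
data_uri = base64_encode_image(image, "image/png")
print(data_uri[:40])  # e.g. "data:image/png;base64,iVBORw0KGgo..."
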
pyproject.toml CHANGED
@@ -38,6 +38,9 @@ dependencies = [
      "semchunk>=2.2.0",
      "tiktoken>=0.8.0",
      "sentence-transformers>=3.2.0",
+     "google-generativeai>=0.8.3",
+     "mistralai>=1.1.0",
+     "instructor>=1.6.3",
  ]

  [project.optional-dependencies]
@@ -61,6 +64,9 @@ core = [
      "torch>=2.4.1",
      "weave>=0.51.14",
      "sentence-transformers>=3.2.0",
+     "google-generativeai>=0.8.3",
+     "mistralai>=1.1.0",
+     "instructor>=1.6.3",
  ]

  dev = ["pytest>=8.3.3", "isort>=5.13.2", "black>=24.10.0", "ruff>=0.6.9"]