Omar Solano
committed on
Commit · 377744c
1 Parent(s): 3f59041
update scraping scripts
Browse files
- data/scraping/huggingface_docs/parse_hf_html.py +0 -166
- data/scraping/huggingface_docs/scrape_hf_docs_from_repo.py +0 -57
- data/scraping/huggingface_docs/scrape_hf_docs_from_web.py +0 -134
- data/scraping/huggingface_docs/validate_jsonl.py +0 -51
- data/scraping_scripts/create_db.ipynb +389 -0
- data/scraping_scripts/create_jsonl_file_hf.py +154 -0
- data/scraping_scripts/create_jsonl_file_llama.py +123 -0
- data/scraping_scripts/get_md_files_from_repo.py +137 -0
- scripts/call_openai.py +0 -79
- scripts/create_db.ipynb +0 -0
data/scraping/huggingface_docs/parse_hf_html.py
DELETED
@@ -1,166 +0,0 @@
import io
import json
import os
from pathlib import Path
from urllib.parse import urljoin

import pandas as pd
from bs4 import BeautifulSoup
from tqdm import tqdm


class HuggingfaceParser:
    def __init__(self, html, url):
        self.soup = BeautifulSoup(html, "html.parser")
        self.url = url

    def find_sections(self):
        sections = []
        main_content = self.soup.find("article", class_="md-content__inner")
        if not main_content:
            main_content = self.soup.find(
                "div", class_="main-container"
            )  # Look for main container
        if not main_content:
            main_content = self.soup.find(
                "body"
            )  # Fallback to body if nothing else found

        if not main_content:
            print(f"Error: No main content found for {self.url}")
            return sections

        # Try to find headers
        headers = main_content.find_all(["h1", "h2", "h3", "h4", "h5", "h6"])

        if not headers:
            # If no headers, look for other structural elements
            headers = main_content.find_all(
                ["div", "p"], class_=["docstring", "section"]
            )

        if not headers:
            print(f"Warning: No headers or sections found in {self.url}")
            # If still no headers, treat the whole content as one section
            title = self.soup.title.string if self.soup.title else "Untitled"
            sections.append(
                {
                    "name": title,
                    "url": self.url,
                    "content": main_content.get_text(strip=True),
                    "level": 1,
                }
            )
            return sections

        for i, header in enumerate(headers):
            name = header.text.strip()
            header_id = header.get("id", "")
            if header_id:
                section_url = f"{self.url}#{header_id}"
            else:
                section_url = self.url

            content = self.extract_content(
                header, headers[i + 1] if i + 1 < len(headers) else None
            )
            sections.append(
                {
                    "name": name,
                    "url": section_url,
                    "content": content,
                    "level": self.get_header_level(header),
                }
            )

        return sections

    def extract_content(self, start_tag, end_tag):
        content = []
        current = start_tag.next_sibling
        while current and current != end_tag:
            if isinstance(current, str):
                content.append(current.strip())
            elif current.name == "table":
                table_html = io.StringIO(str(current))
                content.append(
                    pd.read_html(table_html)[0].to_markdown(
                        index=False, tablefmt="github"
                    )
                )
            elif current.name not in ["script", "style"]:
                content.append(current.get_text(strip=True, separator=" "))
            current = current.next_sibling
        return "\n".join(filter(None, content))

    def get_header_level(self, tag):
        if tag.name in ["h1", "h2", "h3", "h4", "h5", "h6"]:
            return int(tag.name[1])
        elif "class" in tag.attrs:
            if "docstring" in tag["class"]:
                return 1
            elif "section" in tag["class"]:
                return 2
        return 1  # Default level


def is_likely_html_file(file_path):
    excluded_extensions = {".css", ".js", ".png", ".jpg", ".jpeg", ".gif", ".svg"}
    return file_path.suffix == "" or file_path.suffix.lower() not in excluded_extensions


def parse_saved_html_files(html_dir, base_url):
    all_sections = []
    html_files = [
        f for f in Path(html_dir).rglob("*") if f.is_file() and is_likely_html_file(f)
    ]
    print(f"Found {len(html_files)} HTML files")

    for html_file in tqdm(html_files, desc="Parsing HTML files"):
        try:
            with open(html_file, "r", encoding="utf-8") as file:
                html_content = file.read()

            relative_path = html_file.relative_to(html_dir)
            url = urljoin(base_url, str(relative_path).replace(os.path.sep, "/"))

            parser = HuggingfaceParser(html_content, url)
            sections = parser.find_sections()

            if not sections:
                print(f"Warning: No sections found in {html_file}")
            all_sections.extend(sections)
        except Exception as e:
            print(f"Error parsing {html_file}: {str(e)}")

    return all_sections


def save_to_jsonl(data, output_file):
    with open(output_file, "w", encoding="utf-8") as f:
        for item in data:
            json.dump(item, f, ensure_ascii=False)
            f.write("\n")


def main():
    # html_dir = "transformers_docs_v4.42.0"  # Directory where HTML files are saved
    # base_url = "https://huggingface.co/docs/transformers/"

    html_dir = "peft_docs_v0.11.0"  # Directory where HTML files are saved
    base_url = "https://huggingface.co/docs/peft/"

    output_file = "hf_peft_v0_11_0.jsonl"

    all_sections = parse_saved_html_files(html_dir, base_url)
    save_to_jsonl(all_sections, output_file)

    print(f"Parsed content saved to {output_file}")
    print(f"Total sections parsed: {len(all_sections)}")


if __name__ == "__main__":
    main()
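
For quick reference, a minimal sketch of exercising the deleted parser on a single saved page; the file path and page URL below are illustrative assumptions, not values taken from this commit:

from pathlib import Path

from parse_hf_html import HuggingfaceParser

# Hypothetical saved page and the URL it was downloaded from.
html_path = Path("peft_docs_v0.11.0/index.html")
html = html_path.read_text(encoding="utf-8")

parser = HuggingfaceParser(html, "https://huggingface.co/docs/peft/index")
for section in parser.find_sections()[:3]:
    print(section["level"], section["name"], section["url"])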
data/scraping/huggingface_docs/scrape_hf_docs_from_repo.py
DELETED
@@ -1,57 +0,0 @@
import os

import requests

# GitHub repository information
owner = "huggingface"

# repo = "peft"
# path = "docs/source"

repo = "transformers"
path = "docs/source/en"

# GitHub API endpoint for the repository contents
api_url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}"


def get_files_in_directory(api_url):
    response = requests.get(api_url)
    if response.status_code == 200:
        return response.json()
    else:
        print(f"Failed to fetch directory contents: {response.status_code}")
        return []


def download_file(file_url, file_path):
    response = requests.get(file_url)
    if response.status_code == 200:
        with open(file_path, "wb") as file:
            file.write(response.content)
    else:
        print(f"Failed to download file: {response.status_code}")


def fetch_md_files(api_url, local_dir):
    files = get_files_in_directory(api_url)
    for file in files:
        if file["type"] == "file" and file["name"].endswith(".md"):
            file_url = file["download_url"]
            file_path = os.path.join(local_dir, file["name"])
            print(f'Downloading {file["name"]}...')
            download_file(file_url, file_path)
        elif file["type"] == "dir":
            subdir = os.path.join(local_dir, file["name"])
            os.makedirs(subdir, exist_ok=True)
            fetch_md_files(file["url"], subdir)


# Local directory to save the files
local_dir = f"data/{repo}_docs"
os.makedirs(local_dir, exist_ok=True)

# Start fetching files
fetch_md_files(api_url, local_dir)

print("All files have been downloaded.")
data/scraping/huggingface_docs/scrape_hf_docs_from_web.py
DELETED
@@ -1,134 +0,0 @@
import logging
from pathlib import Path
from urllib.parse import unquote, urljoin, urlparse

import scrapy
from scrapy.crawler import CrawlerProcess
from tqdm import tqdm

logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)


def is_valid_url(url, domain, base_path):
    parsed = urlparse(url)
    return (
        parsed.scheme in ["http", "https"]
        and parsed.netloc == domain
        and parsed.path.startswith(base_path)
        and "#" not in url
    )  # Exclude URLs with fragments


def clean_url(url):
    # Replace &amp; with &, and &#35; with #
    url = url.replace("&amp;", "&").replace("&#35;", "#")
    # Decode URL-encoded characters
    return unquote(url)


class DocsSpider(scrapy.Spider):
    name = "docs"

    def __init__(
        self,
        homepage_url: str,
        domain: str,
        base_path: str,
        save_dir="outputs/",
        target_version=None,
        *args,
        **kwargs,
    ):
        super(DocsSpider, self).__init__(*args, **kwargs)
        self.homepage_url = homepage_url
        self.domain = domain
        self.base_path = base_path
        self.allowed_domains = [domain]
        self.start_urls = [self.homepage_url]
        self.base_dir = Path(save_dir)
        self.target_version = target_version
        self.pages = []
        self.progress_bar = None

    def start_requests(self):
        self.progress_bar = tqdm(desc="Crawling pages", unit="page")
        yield scrapy.Request(self.homepage_url, self.parse)

    def parse(self, response):
        if not is_valid_url(response.url, self.domain, self.base_path):
            return

        parsed_uri = urlparse(response.url)
        relative_path = parsed_uri.path.removeprefix(self.base_path).strip("/")
        if relative_path:
            filepath = self.base_dir / relative_path
        else:
            filepath = self.base_dir / "index.html"

        filepath.parent.mkdir(parents=True, exist_ok=True)
        with open(filepath, "wb") as f:
            f.write(response.body)

        self.pages.append({"url": response.url, "html": response.body})
        # if self.progress_bar:
        self.progress_bar.update(1)

        for href in response.css("a::attr(href)").getall():
            full_url = response.urljoin(clean_url(href))
            if is_valid_url(full_url, self.domain, self.base_path):
                if self.target_version:
                    if self.target_version in full_url:
                        yield response.follow(full_url, self.parse)
                else:
                    yield response.follow(full_url, self.parse)

    def closed(self, reason):
        if self.progress_bar:
            self.progress_bar.close()


def crawl_docs(start_url, domain, base_path, save_dir="outputs/", target_version=None):
    process = CrawlerProcess(
        settings={
            "USER_AGENT": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3",
            "DOWNLOAD_DELAY": 2,
            "RANDOMIZE_DOWNLOAD_DELAY": True,
            "CONCURRENT_REQUESTS": 1,
            "RETRY_TIMES": 5,
            "RETRY_HTTP_CODES": [429, 500, 502, 503, 504, 522, 524, 408, 400],
            "HTTPERROR_ALLOWED_CODES": [404],  # Allow 404 errors to be logged
        }
    )

    process.crawl(
        DocsSpider,
        homepage_url=start_url,
        domain=domain,
        base_path=base_path,
        save_dir=save_dir,
        target_version=target_version,
    )
    process.start()

    spider = next(s for s in process.crawlers if s.spider.name == "docs").spider

    print(f"Total pages crawled and parsed: {len(spider.pages)}")


if __name__ == "__main__":
    # https://huggingface.co/docs/peft/v0.11.0/en/index
    # Customizable parameters
    domain = "huggingface.co"
    version = "v0.11.0"
    library = "peft"
    language = "en"

    # Construct URL and paths
    base_path = f"/docs/{library}/{version}/{language}"
    start_url = f"https://{domain}{base_path}/index"
    save_dir = f"{library}_docs_{version}"

    # Optional: Set target_version to None if you want to crawl all versions
    target_version = None

    crawl_docs(start_url, domain, base_path, save_dir, target_version)
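
A short sanity check of the URL filter above, handy before launching a full crawl; the URLs are illustrative:

from scrape_hf_docs_from_web import is_valid_url

base = "/docs/peft/v0.11.0/en"
# In-scope page: correct scheme, domain, base path, no fragment.
assert is_valid_url("https://huggingface.co" + base + "/quicktour", "huggingface.co", base)
# Fragment URLs are rejected so each page is stored only once.
assert not is_valid_url("https://huggingface.co" + base + "/quicktour#install", "huggingface.co", base)
# Pages outside the version/language base path are skipped.
assert not is_valid_url("https://huggingface.co/docs/transformers/index", "huggingface.co", base)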
data/scraping/huggingface_docs/validate_jsonl.py
DELETED
@@ -1,51 +0,0 @@
import json
from typing import Any, Dict, List


def load_and_validate_jsonl(file_path: str) -> Dict[int, Any]:
    """
    Load a .jsonl file into a dictionary and validate each line.

    Args:
        file_path (str): Path to the .jsonl file

    Returns:
        Dict[int, Any]: A dictionary where keys are line numbers (1-indexed) and values are the parsed JSON objects

    Raises:
        ValueError: If any line in the file is not valid JSON
    """
    result = {}
    with open(file_path, "r") as file:
        for line_number, line in enumerate(file, 1):
            try:
                # Strip whitespace and check if the line is empty
                stripped_line = line.strip()
                if not stripped_line:
                    print(f"Warning: Line {line_number} is empty.")
                    continue

                # Attempt to parse the JSON
                parsed_json = json.loads(stripped_line)
                result[line_number] = parsed_json
            except json.JSONDecodeError as e:
                raise ValueError(f"Invalid JSON on line {line_number}: {e}")

    return result


if __name__ == "__main__":
    file_path = "hf_transformers_v4_42_0.jsonl"
    try:
        loaded_data = load_and_validate_jsonl(file_path)
        print(f"Successfully loaded {len(loaded_data)} valid JSON objects.")

        # Optional: Print the first few items
        print("\nFirst few items:")
        for line_number, data in list(loaded_data.items())[:5]:
            print(f"Line {line_number}: {data}")

    except ValueError as e:
        print(f"Error: {e}")
    except FileNotFoundError:
        print(f"Error: File '{file_path}' not found.")
data/scraping_scripts/create_db.ipynb
ADDED
@@ -0,0 +1,389 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create HF vector database\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from dotenv import load_dotenv\n",
    "\n",
    "load_dotenv(\"../../.env\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nest_asyncio\n",
    "\n",
    "nest_asyncio.apply()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a set of Llama-index Documents with each section in the jsonl file\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Doc ID: 682dbc3b-96ff-4ca4-a556-44d3cd8ffa8a\n",
      "Text: # Command Line Interfaces (CLIs) You can use TRL to fine-tune\n",
      "your Language Model with Supervised Fine-Tuning (SFT) or Direct Policy\n",
      "Optimization (DPO) or even chat with your model using the TRL CLIs.\n",
      "Currently supported CLIs are: - `trl sft`: fine-tune a LLM on a\n",
      "text/instruction dataset - `trl dpo`: fine-tune a LLM with DPO on a\n",
      "preference ...\n",
      "{'url': 'https://huggingface.co/docs/trl/clis/', 'title': 'Command Line Interfaces (CLIs)', 'tokens': 1209, 'retrieve_doc': True, 'source': 'TRL'}\n"
     ]
    }
   ],
   "source": [
    "from llama_index.core import Document\n",
    "from llama_index.core.schema import MetadataMode\n",
    "import json\n",
    "import pickle\n",
    "\n",
    "\n",
    "def create_docs(input_file):\n",
    "    with open(input_file, \"r\") as f:\n",
    "        documents = []\n",
    "        for i, line in enumerate(f):\n",
    "            data = json.loads(line)\n",
    "            documents.append(\n",
    "                Document(\n",
    "                    doc_id=data[\"doc_id\"],\n",
    "                    text=data[\"content\"],\n",
    "                    metadata={\n",
    "                        \"url\": data[\"url\"],\n",
    "                        \"title\": data[\"name\"],\n",
    "                        \"tokens\": data[\"tokens\"],\n",
    "                        \"retrieve_doc\": data[\"retrieve_doc\"],\n",
    "                        \"source\": data[\"source\"],\n",
    "                    },\n",
    "                    # LLM will see the 'url' of each chunk\n",
    "                    excluded_llm_metadata_keys=[\n",
    "                        # \"url\",\n",
    "                        \"title\",\n",
    "                        \"tokens\",\n",
    "                        \"retrieve_doc\",\n",
    "                        \"source\",\n",
    "                    ],\n",
    "                    # Embedding model will embed the 'title' of each chunk\n",
    "                    excluded_embed_metadata_keys=[\n",
    "                        \"url\",\n",
    "                        # \"title\",\n",
    "                        \"tokens\",\n",
    "                        \"retrieve_doc\",\n",
    "                        \"source\",\n",
    "                    ],\n",
    "                )\n",
    "            )\n",
    "    return documents\n",
    "\n",
    "\n",
    "# documents = create_docs(\"../transformers_data.jsonl\")\n",
    "# documents = create_docs(\"../peft_data.jsonl\")\n",
    "documents = create_docs(\"../trl_data.jsonl\")\n",
    "# documents = create_docs(\"../llama_index_data.jsonl\")\n",
    "print(documents[0])\n",
    "print(documents[0].metadata)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(\n",
    "#     \"The LLM sees this: \\n\",\n",
    "#     documents[0].get_content(metadata_mode=MetadataMode.LLM),\n",
    "# )\n",
    "print(\n",
    "    \"The Embedding model sees this: \\n\",\n",
    "    documents[0].get_content(metadata_mode=MetadataMode.EMBED),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import chromadb\n",
    "\n",
    "# create client and a new collection\n",
    "DB_COLLECTION = \"chroma-db-trl\"\n",
    "chroma_client = chromadb.PersistentClient(path=f\"../{DB_COLLECTION}\")\n",
    "chroma_collection = chroma_client.create_collection(DB_COLLECTION)\n",
    "\n",
    "\n",
    "from llama_index.vector_stores.chroma import ChromaVectorStore\n",
    "from llama_index.core import StorageContext\n",
    "\n",
    "# Define a storage context object using the created vector database.\n",
    "vector_store = ChromaVectorStore(chroma_collection=chroma_collection)\n",
    "storage_context = StorageContext.from_defaults(vector_store=vector_store)\n",
    "\n",
    "document_dict = {doc.doc_id: doc for doc in documents}\n",
    "DOCUMENT_NAME = f\"../{DB_COLLECTION}/document_dict_trl.pkl\"\n",
    "\n",
    "with open(DOCUMENT_NAME, \"wb\") as f:\n",
    "    pickle.dump(document_dict, f)\n",
    "\n",
    "# with open(DOCUMENT_NAME, \"rb\") as f:\n",
    "#     document_dict = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/omar/Documents/ai_repos/ai-tutor-rag-system/env/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Parsing nodes: 100%|██████████| 33/33 [00:00<00:00, 290.40it/s]\n",
      "Generating embeddings: 100%|██████████| 2/2 [00:01<00:00, 1.13it/s]\n"
     ]
    }
   ],
   "source": [
    "from llama_index.core import VectorStoreIndex\n",
    "from llama_index.core.node_parser import SentenceSplitter\n",
    "from llama_index.embeddings.openai import OpenAIEmbedding\n",
    "\n",
    "index = VectorStoreIndex.from_documents(\n",
    "    documents,\n",
    "    embed_model=OpenAIEmbedding(model=\"text-embedding-3-large\", mode=\"similarity\"),\n",
    "    transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=400)],\n",
    "    show_progress=True,\n",
    "    use_async=True,\n",
    "    storage_context=storage_context,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test the DB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "retriever = index.as_retriever(\n",
    "    similarity_top_k=10,\n",
    "    use_async=True,\n",
    "    embed_model=OpenAIEmbedding(model=\"text-embedding-3-large\", mode=\"similarity\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.data_structs import Node\n",
    "from llama_index.core.schema import NodeWithScore, BaseNode, TextNode\n",
    "\n",
    "\n",
    "# query = \"fine-tune a pretrained model\"\n",
    "# query = \"fine-tune an llm\"\n",
    "query = \"how to fine-tune an llm?\"\n",
    "\n",
    "nodes_context = []\n",
    "nodes = retriever.retrieve(query)\n",
    "\n",
    "\n",
    "# Filter nodes with the same ref_doc_id\n",
    "def filter_nodes_by_unique_doc_id(nodes):\n",
    "    unique_nodes = {}\n",
    "    for node in nodes:\n",
    "        doc_id = node.node.ref_doc_id\n",
    "        if doc_id is not None and doc_id not in unique_nodes:\n",
    "            unique_nodes[doc_id] = node\n",
    "    return list(unique_nodes.values())\n",
    "\n",
    "\n",
    "nodes = filter_nodes_by_unique_doc_id(nodes)\n",
    "print(len(nodes))\n",
    "\n",
    "for node in nodes:\n",
    "    print(\"Node ID\\t\", node.node_id)\n",
    "    print(\"Title\\t\", node.metadata[\"title\"])\n",
    "    print(\"Text\\t\", node.text)\n",
    "    print(\"Score\\t\", node.score)\n",
    "    print(\"Metadata\\t\", node.metadata)\n",
    "    print(\"-_\" * 20)\n",
    "    if node.metadata[\"retrieve_doc\"] == True:\n",
    "        print(\"This node will be replaced by the document\")\n",
    "        doc = document_dict[node.node.ref_doc_id]\n",
    "        # print(doc.text)\n",
    "        new_node = NodeWithScore(\n",
    "            node=TextNode(text=doc.text, metadata=node.metadata), score=node.score\n",
    "        )\n",
    "        print(new_node.text)\n",
    "        nodes_context.append(new_node)\n",
    "    else:\n",
    "        nodes_context.append(node)\n",
    "\n",
    "print(len(nodes_context))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import ChatPromptTemplate\n",
    "from llama_index.core.llms import ChatMessage, MessageRole\n",
    "from pydantic import BaseModel, Field\n",
    "\n",
    "system_prompt = (\n",
    "    \"You are a witty AI teacher, helpfully answering questions from students of an applied artificial intelligence course on Large Language Models (LLMs or llm). Topics covered include training models, fine-tuning models, giving 'memory' to LLMs, prompting, hallucinations and bias, vector databases, transformer architectures, embeddings, RAG frameworks, Langchain, Llama-Index, LLMs interact with tool use, AI agents, reinforcement learning with human feedback. Questions should be understood with this context.\"\n",
    "    \"You are provided information found in Hugging Face's documentation and the RAG course. \"\n",
    "    \"Only some information might be relevant to the question, so ignore the irrelevant part and use the relevant part to answer the question.\"\n",
    "    \"Only respond with information given to you documentation. DO NOT use additional information, even if you know the answer. \"\n",
    "    \"If the answer is somewhere in the documentation, answer the question (depending on the questions and the variety of relevant information in the documentation, give complete and helpful answers.\"\n",
    "    \"Here is the information you can use, the order is not important: \\n\\n\"\n",
    "    \"---------------------\\n\"\n",
    "    \"{context_str}\\n\"\n",
    "    \"---------------------\\n\\n\"\n",
    "    \"REMEMBER:\\n\"\n",
    "    \"You are a witty AI teacher, helpfully answering questions from students of an applied artificial intelligence course on Large Language Models (LLMs or llm). Topics covered include training models, fine tuning models, giving memory to LLMs, prompting, hallucinations and bias, vector databases, transformer architectures, embeddings, RAG frameworks, Langchain, making LLMs interact with tool use, AI agents, reinforcement learning with human feedback. Questions should be understood with this context.\"\n",
    "    \"You are provided information found in Hugging Face's documentation and the RAG course. \"\n",
    "    \"Here are the rules you must follow:\\n\"\n",
    "    \"* Only respond with information inside the documentation. DO NOT provide additional information, even if you know the answer. \"\n",
    "    \"* If the answer is in the documentation, answer the question (depending on the questions and the variety of relevant information in the json documentation. Your answer needs to be pertinent and not redundant giving a clear explanation as if you were a teacher. \"\n",
    "    \"* Only use information summarized from the documentation, do not respond otherwise. \"\n",
    "    \"* Do not refer to the documentation directly, but use the instructions provided within it to answer questions. \"\n",
    "    \"* Do not reference any links, urls or hyperlinks in your answers.\\n\"\n",
    "    \"* Make sure to format your answers in Markdown format, including code block and snippets.\\n\"\n",
    "    \"Now answer the following question: \\n\"\n",
    ")\n",
    "\n",
    "chat_text_qa_msgs: list[ChatMessage] = [\n",
    "    ChatMessage(role=MessageRole.SYSTEM, content=system_prompt),\n",
    "    ChatMessage(\n",
    "        role=MessageRole.USER,\n",
    "        content=\"{query_str}\",\n",
    "    ),\n",
    "]\n",
    "\n",
    "TEXT_QA_TEMPLATE = ChatPromptTemplate(chat_text_qa_msgs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Markdown\n",
    "from llama_index.core.data_structs import Node\n",
    "from llama_index.core.schema import NodeWithScore\n",
    "from llama_index.core import get_response_synthesizer\n",
    "from llama_index.llms.gemini import Gemini\n",
    "from llama_index.llms.openai import OpenAI\n",
    "\n",
    "# llm = Gemini(model=\"models/gemini-1.5-flash\", temperature=1, max_tokens=None)\n",
    "# llm = Gemini(model=\"models/gemini-1.5-pro\", temperature=1, max_tokens=None)\n",
    "# llm = OpenAI(temperature=1, model=\"gpt-3.5-turbo\", max_tokens=None)\n",
    "llm = OpenAI(temperature=1, model=\"gpt-4o-mini\", max_tokens=None)\n",
    "\n",
    "response_synthesizer = get_response_synthesizer(\n",
    "    llm=llm, response_mode=\"simple_summarize\", text_qa_template=TEXT_QA_TEMPLATE\n",
    ")\n",
    "\n",
    "response = response_synthesizer.synthesize(query, nodes=nodes_context)\n",
    "# print(response.response)\n",
    "display(Markdown(response.response))\n",
    "\n",
    "# for src in response.source_nodes:\n",
    "#     print(src.node.ref_doc_id)\n",
    "#     print(\"Node ID\\t\", src.node_id)\n",
    "#     print(\"Title\\t\", src.metadata[\"title\"])\n",
    "#     print(\"Text\\t\", src.text)\n",
    "#     print(\"Score\\t\", src.score)\n",
    "#     print(\"Metadata\\t\", src.metadata)\n",
    "#     print(\"-_\" * 20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
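
A hedged sketch of reopening the persisted collection later without re-embedding; it mirrors the paths, collection name and embedding settings used in the notebook above and should be adjusted to whichever source was actually indexed:

import pickle

import chromadb
from llama_index.core import VectorStoreIndex
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore

DB_COLLECTION = "chroma-db-trl"  # same name as in the notebook
chroma_client = chromadb.PersistentClient(path=f"../{DB_COLLECTION}")
chroma_collection = chroma_client.get_collection(DB_COLLECTION)

# Rebuild an index view over the existing vector store instead of re-embedding.
index = VectorStoreIndex.from_vector_store(
    ChromaVectorStore(chroma_collection=chroma_collection),
    embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
)

# The pickled document_dict saved by the notebook maps doc_id -> full Document.
with open(f"../{DB_COLLECTION}/document_dict_trl.pkl", "rb") as f:
    document_dict = pickle.load(f)

retriever = index.as_retriever(similarity_top_k=10)
for node in retriever.retrieve("how to fine-tune an llm?"):
    print(round(node.score, 3), node.metadata["title"])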
data/scraping_scripts/create_jsonl_file_hf.py
ADDED
@@ -0,0 +1,154 @@
import json
import os
import re
import uuid

import tiktoken

BASE_URL = "https://huggingface.co/docs/transformers/"
# BASE_URL = "https://huggingface.co/docs/peft/"
# BASE_URL = "https://huggingface.co/docs/trl/"

# List of directories to include (relative to the main input directory)
INCLUDED_DIRS = [
    # Add more directories here as needed
]

# List of directories to exclude (relative to the main input directory)
EXCLUDED_DIRS = [
    # "some_directory_to_exclude",
    # Add more directories here as needed
    "internal",
    "main_classes",
]

# List of specific files to exclude from the root directory
EXCLUDED_ROOT_FILES = [
    # "some_file_to_exclude.md",
    # Add more files here as needed
]

# Set this to True to use the INCLUDED_DIRS list, or False to use the EXCLUDED_DIRS list
USE_INCLUDE_LIST = False


def extract_title(content):
    # Try to find a Markdown title (# Title)
    title_match = re.search(r"^#\s+(.+)$", content, re.MULTILINE)
    if title_match:
        return title_match.group(1).strip()

    # If no Markdown title, use the first non-empty line
    lines = content.split("\n")
    for line in lines:
        if line.strip():
            return line.strip()

    # If file is empty, return None
    return None


def generate_url(file_path):
    # Remove the file extension
    path_without_extension = os.path.splitext(file_path)[0]

    # Replace backslashes with forward slashes for Windows compatibility
    path_with_forward_slashes = path_without_extension.replace("\\", "/")

    # Combine with base URL
    return BASE_URL + path_with_forward_slashes + "/"


def should_include_file(file_path):
    # Check if the file is directly in the root
    if os.path.dirname(file_path) == "":
        return os.path.basename(file_path) not in EXCLUDED_ROOT_FILES

    if USE_INCLUDE_LIST:
        # Check if the file is in one of the included directories
        return any(file_path.startswith(dir) for dir in INCLUDED_DIRS)
    else:
        # Check if the file is not in any of the excluded directories
        return not any(file_path.startswith(dir) for dir in EXCLUDED_DIRS)


def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Returns the number of tokens in a text string."""
    encoding = tiktoken.get_encoding(encoding_name)
    num_tokens = len(
        encoding.encode(
            string, disallowed_special=(encoding.special_tokens_set - {"<|endoftext|>"})
        )
    )
    return num_tokens


def remove_copyright_header(content):
    # Pattern to match the copyright header
    header_pattern = re.compile(r"<!--Copyright.*?-->\s*", re.DOTALL)

    # Remove the header
    cleaned_content = header_pattern.sub("", content, count=1)

    return cleaned_content.strip()


def process_md_files(directory):
    jsonl_data = []

    for root, _, files in os.walk(directory):
        for file in files:
            if file.endswith(".md") or file.endswith(".mdx"):
                file_path = os.path.join(root, file)
                relative_path = os.path.relpath(file_path, directory)

                # Only process the file if it should be included
                if should_include_file(relative_path):
                    with open(file_path, "r", encoding="utf-8") as f:
                        content = f.read()

                    title = extract_title(content)
                    token_count = num_tokens_from_string(content, "cl100k_base")
                    if token_count < 100:
                        continue
                    cleaned_content = remove_copyright_header(content)

                    json_object = {
                        "tokens": token_count,
                        "doc_id": str(uuid.uuid4()),
                        "name": (title if title else file),
                        "url": generate_url(relative_path),
                        "retrieve_doc": (True if token_count <= 8000 else False),
                        # "source": "TRL",
                        # "source": "PEFT",
                        "source": "HF_Transformers",
                        "content": cleaned_content,
                    }

                    jsonl_data.append(json_object)

    return jsonl_data


def save_jsonl(data, output_file):
    with open(output_file, "w", encoding="utf-8") as f:
        for item in data:
            json.dump(item, f, ensure_ascii=False)
            f.write("\n")


# Directory where the .md files are located
input_directory = "data/transformers_md_files"
# input_directory = "data/peft_md_files"
# input_directory = "data/trl_md_files"

# Output .jsonl file
output_file = "data/transformers_data.jsonl"
# output_file = "data/peft_data.jsonl"
# output_file = "data/trl_data.jsonl"

# Process the files and save to JSONL
jsonl_data = process_md_files(input_directory)
save_jsonl(jsonl_data, output_file)

print(f"Processed {len(jsonl_data)} files and saved to {output_file}")
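A hedged follow-up for sanity-checking the generated file; the field names match the json_object built above and the path is the script's default output:

import json

with open("data/transformers_data.jsonl", "r", encoding="utf-8") as f:
    records = [json.loads(line) for line in f]

# retrieve_doc is False for documents above the 8,000-token cutoff.
long_docs = [r for r in records if not r["retrieve_doc"]]
print(f"{len(records)} records, {len(long_docs)} above the 8,000-token cutoff")
print(records[0]["name"], records[0]["url"])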
data/scraping_scripts/create_jsonl_file_llama.py
ADDED
@@ -0,0 +1,123 @@
import json
import os
import re
import uuid

import tiktoken

BASE_URL = "https://docs.llamaindex.ai/en/stable/"

# List of directories to include (relative to the main input directory)
INCLUDED_DIRS = [
    "getting_started",
    "understanding",
    "use_cases",
    "examples",
    "module_guides",
    "optimizing",
]

# List of specific files to include from the root directory
INCLUDED_ROOT_FILES = [
    "index.md",
    # Add more files here as needed
]


def extract_title(content):
    # Try to find a Markdown title (# Title)
    title_match = re.search(r"^#\s+(.+)$", content, re.MULTILINE)
    if title_match:
        return title_match.group(1).strip()

    # If no Markdown title, use the first non-empty line
    lines = content.split("\n")
    for line in lines:
        if line.strip():
            return line.strip()

    # If file is empty, return None
    return None


def generate_url(file_path):
    # Remove the file extension
    path_without_extension = os.path.splitext(file_path)[0]

    # Replace backslashes with forward slashes for Windows compatibility
    path_with_forward_slashes = path_without_extension.replace("\\", "/")

    # Combine with base URL
    return BASE_URL + path_with_forward_slashes + "/"


def should_include_file(file_path):
    # Check if the file is directly in the root and in the INCLUDED_ROOT_FILES list
    if os.path.dirname(file_path) == "":
        return os.path.basename(file_path) in INCLUDED_ROOT_FILES

    # Check if the file is in one of the included directories
    return any(file_path.startswith(dir) for dir in INCLUDED_DIRS)


def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Returns the number of tokens in a text string."""
    encoding = tiktoken.get_encoding(encoding_name)
    num_tokens = len(
        encoding.encode(
            string, disallowed_special=(encoding.special_tokens_set - {"<|endoftext|>"})
        )
    )
    return num_tokens


def process_md_files(directory):
    jsonl_data = []

    for root, _, files in os.walk(directory):
        for file in files:
            if file.endswith(".md") or file.endswith(".mdx"):
                file_path = os.path.join(root, file)
                relative_path = os.path.relpath(file_path, directory)

                # Only process the file if it should be included
                if should_include_file(relative_path):
                    with open(file_path, "r", encoding="utf-8") as f:
                        content = f.read()

                    title = extract_title(content)
                    token_count = num_tokens_from_string(content, "cl100k_base")

                    json_object = {
                        "tokens": token_count,
                        "doc_id": str(uuid.uuid4()),
                        "name": (title if title else file),
                        "url": generate_url(relative_path),
                        "retrieve_doc": (True if token_count <= 8000 else False),
                        "source": "LlamaIndex",
                        "content": content,
                    }

                    jsonl_data.append(json_object)

    return jsonl_data


def save_jsonl(data, output_file):
    with open(output_file, "w", encoding="utf-8") as f:
        for item in data:
            json.dump(item, f, ensure_ascii=False)
            f.write("\n")


# Directory where the .md files are located
input_directory = "data/llama_index_md_files"

# Output .jsonl file
output_file = "data/llama_index_data.jsonl"

# Process the files and save to JSONL
jsonl_data = process_md_files(input_directory)
save_jsonl(jsonl_data, output_file)

print(f"Processed {len(jsonl_data)} files and saved to {output_file}")
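The notebook above reads one of these JSONL files at a time; if a combined corpus were ever needed, a simple concatenation along these lines would work (the file names are assumptions based on the defaults in these scripts):

import json

SOURCES = [
    "data/transformers_data.jsonl",
    "data/peft_data.jsonl",
    "data/trl_data.jsonl",
    "data/llama_index_data.jsonl",
]

with open("data/all_sources_data.jsonl", "w", encoding="utf-8") as out:
    for path in SOURCES:
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                # Re-serialize to normalize whitespace and guarantee one object per line.
                json.dump(json.loads(line), out, ensure_ascii=False)
                out.write("\n")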
data/scraping_scripts/get_md_files_from_repo.py
ADDED
@@ -0,0 +1,137 @@
import json
import os
import random
import time

import nbformat
import requests
from nbconvert import MarkdownExporter

# GitHub repository information
owner = "huggingface"
repo = "transformers"
path = "docs/source/en"

# owner = "huggingface"
# repo = "peft"
# path = "docs/source"

# owner = "huggingface"
# repo = "trl"
# path = "docs/source"

# GitHub repository information
# owner = "run-llama"
# repo = "llama_index"
# path = "docs/docs"

# GitHub API endpoint for the repository contents
api_url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}"

# GitHub Personal Access Token (set the GITHUB_TOKEN environment variable to your own token)
github_token = os.getenv("GITHUB_TOKEN", "")

# Headers for authenticated requests
headers = {
    "Authorization": f"token {github_token}",
    "Accept": "application/vnd.github.v3+json",
}

# Maximum number of retries
MAX_RETRIES = 5


def check_rate_limit():
    rate_limit_url = "https://api.github.com/rate_limit"
    response = requests.get(rate_limit_url, headers=headers)
    data = response.json()
    remaining = data["resources"]["core"]["remaining"]
    reset_time = data["resources"]["core"]["reset"]

    if remaining < 10:  # Adjust this threshold as needed
        wait_time = reset_time - time.time()
        print(f"Rate limit nearly exceeded. Waiting for {wait_time:.2f} seconds.")
        time.sleep(wait_time + 1)  # Add 1 second buffer


def get_files_in_directory(api_url, retries=0):
    try:
        check_rate_limit()
        response = requests.get(api_url, headers=headers)
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        if retries < MAX_RETRIES:
            wait_time = (2**retries) + random.random()
            print(
                f"Error fetching directory contents: {e}. Retrying in {wait_time:.2f} seconds..."
            )
            time.sleep(wait_time)
            return get_files_in_directory(api_url, retries + 1)
        else:
            print(
                f"Failed to fetch directory contents after {MAX_RETRIES} retries: {e}"
            )
            return []


def download_file(file_url, file_path, retries=0):
    try:
        check_rate_limit()
        response = requests.get(file_url, headers=headers)
        response.raise_for_status()
        with open(file_path, "wb") as file:
            file.write(response.content)
    except requests.exceptions.RequestException as e:
        if retries < MAX_RETRIES:
            wait_time = (2**retries) + random.random()
            print(
                f"Error downloading file: {e}. Retrying in {wait_time:.2f} seconds..."
            )
            time.sleep(wait_time)
            download_file(file_url, file_path, retries + 1)
        else:
            print(f"Failed to download file after {MAX_RETRIES} retries: {e}")


def convert_ipynb_to_md(ipynb_path, md_path):
    with open(ipynb_path, "r", encoding="utf-8") as f:
        notebook = nbformat.read(f, as_version=4)

    exporter = MarkdownExporter()
    markdown, _ = exporter.from_notebook_node(notebook)

    with open(md_path, "w", encoding="utf-8") as f:
        f.write(markdown)


def fetch_files(api_url, local_dir):
    files = get_files_in_directory(api_url)
    for file in files:
        if file["type"] == "file" and file["name"].endswith((".md", ".mdx", ".ipynb")):
            file_url = file["download_url"]
            file_name = file["name"]
            file_path = os.path.join(local_dir, file_name)
            print(f"Downloading {file_name}...")
            download_file(file_url, file_path)

            if file_name.endswith(".ipynb"):
                md_file_name = file_name.replace(".ipynb", ".md")
                md_file_path = os.path.join(local_dir, md_file_name)
                print(f"Converting {file_name} to markdown...")
                convert_ipynb_to_md(file_path, md_file_path)
                os.remove(file_path)  # Remove the .ipynb file after conversion
        elif file["type"] == "dir":
            subdir = os.path.join(local_dir, file["name"])
            os.makedirs(subdir, exist_ok=True)
            fetch_files(file["url"], subdir)


# Local directory to save the files
local_dir = f"data/{repo}_md_files"
os.makedirs(local_dir, exist_ok=True)

# Start fetching files
fetch_files(api_url, local_dir)

print("All files have been downloaded and converted.")
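A hedged sketch of running the same notebook-to-Markdown conversion by hand on a single file; the notebook path is hypothetical:

import nbformat
from nbconvert import MarkdownExporter

# Hypothetical notebook downloaded by the script before it is converted and removed.
with open("data/transformers_md_files/some_guide.ipynb", "r", encoding="utf-8") as f:
    notebook = nbformat.read(f, as_version=4)

markdown, _ = MarkdownExporter().from_notebook_node(notebook)
print(markdown[:500])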
scripts/call_openai.py
DELETED
@@ -1,79 +0,0 @@
import os
import logging

import instructor
import openai
from openai import OpenAI, AsyncOpenAI
from dotenv import load_dotenv

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

load_dotenv(".env")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")


def api_function_call(
    system_message,
    query: str,
    model: str = "gpt-4o",
    response_model=None,
    max_retries: int = 0,
    stream: bool = False,
):

    client = instructor.patch(OpenAI())
    try:
        message_data = {
            "model": model,
            "messages": [
                {"role": "system", "content": system_message},
                {"role": "user", "content": query},
            ],
            "max_retries": max_retries,
            "stream": stream,
        }
        if response_model is not None:
            message_data["response_model"] = response_model

        response = client.chat.completions.create(**message_data)
        error = False

    except openai.BadRequestError:
        error = True
        logger.exception("Invalid request to OpenAI API. See traceback:")
        error_message = (
            "Something went wrong while connecting with OpenAI, try again soon!"
        )
        return error_message, error

    except openai.RateLimitError:
        error = True
        logger.exception("RateLimit error from OpenAI. See traceback:")
        error_message = "OpenAI servers seem to be overloaded, try again later!"
        return error_message, error

    except Exception as e:
        error = True
        logger.exception(
            "Some kind of error happened trying to generate the response. See traceback:"
        )
        error_message = (
            "Something went wrong with connecting with OpenAI, try again soon!"
        )
        return error_message, error

    if stream is True and response_model is None:

        def answer_generator():
            for chunk in response:
                token = chunk.choices[0].delta.content

                token = "" if token is None else token

                yield token

        return answer_generator(), error

    else:
        return response, error
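
A hedged usage sketch for the deleted helper (run from the scripts/ directory); the system message and query are placeholders:

from call_openai import api_function_call

response, error = api_function_call(
    system_message="You are a concise technical assistant.",
    query="Summarize what LoRA does in one sentence.",
    model="gpt-4o",
    stream=False,
)
if not error:
    # Without a response_model, the patched client returns a regular chat completion.
    print(response.choices[0].message.content)
else:
    print(response)  # error message string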
scripts/create_db.ipynb
DELETED
The diff for this file is too large to render.
See raw diff