mindspark121 committed on
Commit f411a8f · verified · 1 Parent(s): d08a783

Update app.py

Files changed (1): app.py +11 -6
app.py CHANGED
@@ -1,3 +1,4 @@
+import os
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from sentence_transformers import SentenceTransformer
@@ -7,7 +8,10 @@ import os
 import logging
 from groq import Groq
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
-
+# ✅ Set a writable cache directory
+os.environ["HF_HOME"] = "/tmp/huggingface"
+os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
+os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/huggingface"
 # ✅ Initialize FastAPI
 app = FastAPI()
 
@@ -19,11 +23,12 @@ if not GROQ_API_KEY:
 client = Groq(api_key=GROQ_API_KEY)  # ✅ Ensure the API key is passed correctly
 
 
-# ✅ Load AI Models
-similarity_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
-embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
-summarization_model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base")
-summarization_tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")
+# ✅ Load AI Models (Now uses /tmp/huggingface as cache)
+similarity_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", cache_folder="/tmp/huggingface")
+embedding_model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder="/tmp/huggingface")
+summarization_model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base", cache_dir="/tmp/huggingface")
+summarization_tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base", cache_dir="/tmp/huggingface")
+
 
 # ✅ Load datasets
 try:
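
For reference, a minimal standalone sketch of the cache setup this commit introduces, assuming an environment where only /tmp is writable (e.g. a Hugging Face Space); the directory creation step and the model name used here are illustrative, not part of the commit. The key point is that the cache environment variables are set before any model download so transformers and sentence-transformers write under /tmp/huggingface.

```python
# Minimal sketch, assuming only /tmp is writable (e.g. a Hugging Face Space).
import os

CACHE_DIR = "/tmp/huggingface"
os.makedirs(CACHE_DIR, exist_ok=True)                  # ensure the cache directory exists
os.environ["HF_HOME"] = CACHE_DIR                      # top-level Hugging Face cache
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR           # transformers download cache
os.environ["SENTENCE_TRANSFORMERS_HOME"] = CACHE_DIR   # sentence-transformers cache

# Import after the env vars are set so the library picks up the writable paths.
from sentence_transformers import SentenceTransformer

# cache_folder mirrors the env var, keeping downloads under /tmp/huggingface.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=CACHE_DIR)
```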