limcheekin commited on
Commit
9190101
1 Parent(s): df03c96

feat: updated open.text.embeddings.server

Browse files
Files changed (1) hide show
  1. open/text/embeddings/server/app.py +10 -11
open/text/embeddings/server/app.py CHANGED
@@ -61,13 +61,14 @@ embeddings = None
61
 
62
 
63
  def _create_embedding(
64
- request: CreateEmbeddingRequest
 
65
  ):
66
  global embeddings
67
 
68
  if embeddings is None:
69
- if request.model and request.model != "text-embedding-ada-002":
70
- model_name = request.model
71
  else:
72
  model_name = os.environ["MODEL"]
73
  print("Loading model:", model_name)
@@ -92,11 +93,11 @@ def _create_embedding(
92
  embeddings = HuggingFaceEmbeddings(
93
  model_name=model_name, encode_kwargs=encode_kwargs)
94
 
95
- if isinstance(request.input, str):
96
- return CreateEmbeddingResponse(data=[Embedding(embedding=embeddings.embed_query(request.input))])
97
  else:
98
  data = [Embedding(embedding=embedding)
99
- for embedding in embeddings.embed_documents(request.input)]
100
  return CreateEmbeddingResponse(data=data)
101
 
102
 
@@ -107,8 +108,6 @@ def _create_embedding(
107
  async def create_embedding(
108
  request: CreateEmbeddingRequest
109
  ):
110
- return _create_embedding(request)
111
- # throw TypeError: 'CreateEmbeddingResponse' object is not callable?
112
- # return await run_in_threadpool(
113
- # _create_embedding(request)
114
- # )
 
61
 
62
 
63
  def _create_embedding(
64
+ model: Optional[str],
65
+ input: Union[str, List[str]]
66
  ):
67
  global embeddings
68
 
69
  if embeddings is None:
70
+ if model and model != "text-embedding-ada-002":
71
+ model_name = model
72
  else:
73
  model_name = os.environ["MODEL"]
74
  print("Loading model:", model_name)
 
93
  embeddings = HuggingFaceEmbeddings(
94
  model_name=model_name, encode_kwargs=encode_kwargs)
95
 
96
+ if isinstance(input, str):
97
+ return CreateEmbeddingResponse(data=[Embedding(embedding=embeddings.embed_query(input))])
98
  else:
99
  data = [Embedding(embedding=embedding)
100
+ for embedding in embeddings.embed_documents(input)]
101
  return CreateEmbeddingResponse(data=data)
102
 
103
 
 
108
  async def create_embedding(
109
  request: CreateEmbeddingRequest
110
  ):
111
+ return await run_in_threadpool(
112
+ _create_embedding, **request.dict(exclude={"user"})
113
+ )