Update splade_encoder.py
Browse files- splade_encoder.py +3 -4
splade_encoder.py
CHANGED
@@ -26,19 +26,18 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
26 |
SOFTWARE.
|
27 |
"""
|
28 |
|
|
|
29 |
import logging
|
|
|
30 |
from typing import Dict, List, Optional
|
31 |
-
|
32 |
-
import torch
|
33 |
from scipy.sparse import csr_array, vstack
|
34 |
-
|
35 |
from milvus_model.base import BaseEmbeddingFunction
|
|
|
36 |
from milvus_model.utils import import_transformers, import_scipy, import_torch
|
37 |
|
38 |
import_torch()
|
39 |
import_scipy()
|
40 |
import_transformers()
|
41 |
-
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
42 |
|
43 |
logger = logging.getLogger(__name__)
|
44 |
logger.setLevel(logging.DEBUG)
|
|
|
26 |
SOFTWARE.
|
27 |
"""
|
28 |
|
29 |
+
import torch
|
30 |
import logging
|
31 |
+
import onnxruntime as ort
|
32 |
from typing import Dict, List, Optional
|
|
|
|
|
33 |
from scipy.sparse import csr_array, vstack
|
|
|
34 |
from milvus_model.base import BaseEmbeddingFunction
|
35 |
+
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
36 |
from milvus_model.utils import import_transformers, import_scipy, import_torch
|
37 |
|
38 |
import_torch()
|
39 |
import_scipy()
|
40 |
import_transformers()
|
|
|
41 |
|
42 |
logger = logging.getLogger(__name__)
|
43 |
logger.setLevel(logging.DEBUG)
|