Support Ollama (#261)

### What problem does this PR solve?

Issue link: #221

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
KevinHuSh
2024-04-08 19:20:57 +08:00
committed by GitHub
parent 265a7a283a
commit 3708b97db9
15 changed files with 234 additions and 43 deletions

View File

@ -16,13 +16,12 @@
from zhipuai import ZhipuAI
import os
from abc import ABC
from ollama import Client
import dashscope
from openai import OpenAI
from FlagEmbedding import FlagModel
import torch
import numpy as np
from huggingface_hub import snapshot_download
from api.utils.file_utils import get_project_base_directory
from rag.utils import num_tokens_from_string
@ -150,3 +149,24 @@ class ZhipuEmbed(Base):
res = self.client.embeddings.create(input=text,
model=self.model_name)
return np.array(res.data[0].embedding), res.usage.total_tokens
class OllamaEmbed(Base):
    """Embedding model that delegates to an Ollama ``Client``.

    Each text is embedded with its own ``client.embeddings`` call. The
    token count returned alongside the vectors is a flat 128 per text,
    not a measured value.
    """

    def __init__(self, key, model_name, **kwargs):
        """Create the client and remember the model name.

        Args:
            key: Accepted for signature compatibility; not used here.
            model_name: Model identifier forwarded on every request.
            **kwargs: Must contain ``"base_url"``, which is passed to
                ``Client`` as ``host``.

        Raises:
            KeyError: If ``"base_url"`` is missing from ``kwargs``.
        """
        self.model_name = model_name
        self.client = Client(host=kwargs["base_url"])

    def _embed(self, text):
        """Return the ``"embedding"`` entry of the response for ``text``."""
        response = self.client.embeddings(prompt=text,
                                          model=self.model_name)
        return response["embedding"]

    def encode(self, texts: list, batch_size=32):
        """Embed every entry of ``texts``, one request per entry.

        Args:
            texts: Texts to embed.
            batch_size: Not used by this implementation.

        Returns:
            A tuple ``(embeddings, token_count)`` where ``embeddings`` is
            ``np.array`` over the collected vectors and ``token_count``
            is 128 multiplied by the number of texts.
        """
        vectors = [self._embed(item) for item in texts]
        # Flat 128 tokens per text, same as accumulating 128 per request.
        return np.array(vectors), 128 * len(vectors)

    def encode_queries(self, text):
        """Embed a single query text.

        Returns:
            A tuple ``(embedding, 128)``; the second element is the same
            flat per-text token figure used by ``encode``.
        """
        return np.array(self._embed(text)), 128