Support Ollama (#261)

### What problem does this PR solve?

Issue link: #221

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
KevinHuSh
2024-04-08 19:20:57 +08:00
committed by GitHub
parent 265a7a283a
commit 3708b97db9
15 changed files with 234 additions and 43 deletions

View File

@@ -19,7 +19,7 @@ from .cv_model import *
EmbeddingModel = {
    "Local": HuEmbedding,
    "Ollama": OllamaEmbed,
    "OpenAI": OpenAIEmbed,
    "Tongyi-Qianwen": HuEmbedding,  # QWenEmbed,
    "ZHIPU-AI": ZhipuEmbed,
@@ -29,7 +29,7 @@ EmbeddingModel = {
CvModel = {
    "OpenAI": GptV4,
    "Local": LocalCV,
    "Ollama": OllamaCV,
    "Tongyi-Qianwen": QWenCV,
    "ZHIPU-AI": Zhipu4V,
    "Moonshot": LocalCV
@@ -40,7 +40,7 @@ ChatModel = {
    "OpenAI": GptTurbo,
    "ZHIPU-AI": ZhipuChat,
    "Tongyi-Qianwen": QWenChat,
    "Local": LocalLLM,
    "Ollama": OllamaChat,
    "Moonshot": MoonshotChat
}
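
Each of these dicts is a small factory: the provider name shown in the UI keys an implementation class, so adding Ollama support is one entry per modality. A minimal sketch of how a caller might resolve the new entries, assuming the dicts are exported from the `rag.llm` package as the relative import above suggests; the model names and host URL are placeholders, not part of this PR:

```python
from rag.llm import ChatModel, EmbeddingModel

# Hypothetical model names and host; substitute whatever is pulled in Ollama.
chat_mdl = ChatModel["Ollama"](key="", model_name="mistral",
                               base_url="http://localhost:11434")
embd_mdl = EmbeddingModel["Ollama"](key="", model_name="nomic-embed-text",
                                    base_url="http://localhost:11434")
```

`key` is unused for Ollama but kept so every factory class shares the same constructor shape.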

View File

@@ -18,6 +18,7 @@ from dashscope import Generation
from abc import ABC
from openai import OpenAI
import openai
from ollama import Client
from rag.nlp import is_english
from rag.utils import num_tokens_from_string
@@ -129,6 +130,32 @@ class ZhipuChat(Base):
            return "**ERROR**: " + str(e), 0


class OllamaChat(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = Client(host=kwargs["base_url"])
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            # Map the generic generation config onto Ollama's option names.
            options = {"temperature": gen_conf.get("temperature", 0.1),
                       "num_predict": gen_conf.get("max_tokens", 128),
                       "top_p": gen_conf.get("top_p", 0.3),
                       "presence_penalty": gen_conf.get("presence_penalty", 0.4),
                       "frequency_penalty": gen_conf.get("frequency_penalty", 0.7),
                       }
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                options=options
            )
            ans = response["message"]["content"].strip()
            return ans, response["eval_count"]
        except Exception as e:
            return "**ERROR**: " + str(e), 0
class LocalLLM(Base):
    class RPCProxy:
        def __init__(self, host, port):
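
For context, a hedged usage sketch of the new `OllamaChat` wrapper; the model name, host, and prompt are illustrative, and the second return value is Ollama's `eval_count` (generated-token count):

```python
from rag.llm.chat_model import OllamaChat

# Placeholders: any chat-capable model already pulled into Ollama works here.
mdl = OllamaChat(key="", model_name="mistral", base_url="http://localhost:11434")
answer, token_count = mdl.chat(
    system="You are a concise assistant.",
    history=[{"role": "user", "content": "Why is the sky blue?"}],
    gen_conf={"temperature": 0.2, "max_tokens": 256},
)
print(answer, token_count)
```

Note that `chat()` mutates `history` in place when a system prompt is given; callers that reuse the list should pass a copy.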

View File

@@ -16,7 +16,7 @@
from zhipuai import ZhipuAI
import io
from abc import ABC
from ollama import Client
from PIL import Image
from openai import OpenAI
import os
@@ -140,6 +140,28 @@ class Zhipu4V(Base):
        return res.choices[0].message.content.strip(), res.usage.total_tokens


class OllamaCV(Base):
    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        self.client = Client(host=kwargs["base_url"])
        self.model_name = model_name
        self.lang = lang

    def describe(self, image, max_tokens=1024):
        prompt = self.prompt("")
        try:
            options = {"num_predict": max_tokens}
            response = self.client.generate(
                model=self.model_name,
                prompt=prompt[0]["content"][1]["text"],
                images=[image],
                options=options
            )
            ans = response["response"].strip()
            # Ollama reports no usage figure that maps onto the callers'
            # expectations here, so 128 is a rough placeholder token count.
            return ans, 128
        except Exception as e:
            return "**ERROR**: " + str(e), 0
class LocalCV(Base):
    def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
        pass
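
A hedged sketch of driving the new `OllamaCV` wrapper, assuming a multimodal model such as `llava` is available locally and that the installed `ollama` client accepts raw image bytes in `images` (both are assumptions, not claims from this PR):

```python
from rag.llm.cv_model import OllamaCV

mdl = OllamaCV(key="", model_name="llava", base_url="http://localhost:11434")
with open("figure.png", "rb") as f:  # any local image; the path is a placeholder
    description, token_count = mdl.describe(f.read(), max_tokens=512)
print(description)
```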

View File

@@ -16,13 +16,12 @@
from zhipuai import ZhipuAI
import os
from abc import ABC
from ollama import Client
import dashscope
from openai import OpenAI
from FlagEmbedding import FlagModel
import torch
import numpy as np
from huggingface_hub import snapshot_download
from api.utils.file_utils import get_project_base_directory
from rag.utils import num_tokens_from_string
@@ -150,3 +149,24 @@ class ZhipuEmbed(Base):
        res = self.client.embeddings.create(input=text,
                                            model=self.model_name)
        return np.array(res.data[0].embedding), res.usage.total_tokens


class OllamaEmbed(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = Client(host=kwargs["base_url"])
        self.model_name = model_name

    def encode(self, texts: list, batch_size=32):
        # Ollama embeds one prompt per request, so the texts are sent one by
        # one and batch_size is effectively unused.
        arr = []
        tks_num = 0
        for txt in texts:
            res = self.client.embeddings(prompt=txt,
                                         model=self.model_name)
            arr.append(res["embedding"])
            # The embeddings endpoint returns no usage data; 128 per text is
            # a rough placeholder.
            tks_num += 128
        return np.array(arr), tks_num

    def encode_queries(self, text):
        res = self.client.embeddings(prompt=text,
                                     model=self.model_name)
        return np.array(res["embedding"]), 128
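
A short usage sketch for the embedding wrapper; the model name and host are placeholders, and the returned token figure is the per-text estimate noted above:

```python
from rag.llm.embedding_model import OllamaEmbed

mdl = OllamaEmbed(key="", model_name="nomic-embed-text",
                  base_url="http://localhost:11434")
vectors, token_estimate = mdl.encode(["first passage", "second passage"])
print(vectors.shape)  # (2, <embedding dimension of the model>)
query_vec, _ = mdl.encode_queries("what is RAG?")
```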

View File

@@ -23,7 +23,8 @@ import re
import sys
import traceback
from functools import partial
import signal
from contextlib import contextmanager
from rag.settings import database_logger
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
@@ -97,8 +98,21 @@ def collect(comm, mod, tm):
    cron_logger.info("TOTAL:{}, To:{}".format(len(tasks), mtm))
    return tasks


@contextmanager
def timeout(time):
    # Register a handler that raises a TimeoutError when SIGALRM fires.
    signal.signal(signal.SIGALRM, raise_timeout)
    # Schedule the signal to be sent after ``time`` seconds.
    signal.alarm(time)
    try:
        yield
    finally:
        # Cancel the pending alarm so it cannot fire after the block exits.
        signal.alarm(0)


def raise_timeout(signum, frame):
    raise TimeoutError
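
The context manager is SIGALRM-based, so it only works on Unix and only in the main thread, and `signal.alarm` expects whole seconds. A minimal sketch of its behavior; the durations are illustrative:

```python
import time

try:
    with timeout(2):   # two-second budget
        time.sleep(5)  # deliberately overruns it
except TimeoutError:
    print("timed out")  # raise_timeout fired via SIGALRM
```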
def build(row):
    from timeit import default_timer as timer
    if row["size"] > DOC_MAXIMUM_SIZE:
        set_progress(row["id"], prog=-1, msg="File size exceeds (<= %dMb)" %
                     (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
@@ -111,11 +125,14 @@ def build(row):
                     row["to_page"])
    chunker = FACTORY[row["parser_id"].lower()]
    try:
        st = timer()
        # Fetch the binary under a 30-second budget so a stalled MinIO read
        # cannot hang the task executor.
        with timeout(30):
            binary = MINIO.get(row["kb_id"], row["location"])
        cks = chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
                            to_page=row["to_page"], lang=row["language"], callback=callback,
                            kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
        cron_logger.info(
            "Chunking({}) {}/{}".format(timer() - st, row["location"], row["name"]))
    except Exception as e:
        if re.search("(No such file|not found)", str(e)):
            callback(-1, "Can not find file <%s>" % row["name"])