Mirror of https://github.com/infiniflow/ragflow.git
Support Ollama (#261)
### What problem does this PR solve?

Issue link: #221

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
@@ -19,7 +19,7 @@ from .cv_model import *

 EmbeddingModel = {
     "Local": HuEmbedding,
+    "Ollama": OllamaEmbed,
     "OpenAI": OpenAIEmbed,
     "Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
     "ZHIPU-AI": ZhipuEmbed,
@@ -29,7 +29,7 @@ EmbeddingModel = {

 CvModel = {
     "OpenAI": GptV4,
     "Local": LocalCV,
+    "Ollama": OllamaCV,
     "Tongyi-Qianwen": QWenCV,
     "ZHIPU-AI": Zhipu4V,
     "Moonshot": LocalCV
@@ -40,7 +40,7 @@ ChatModel = {
     "OpenAI": GptTurbo,
     "ZHIPU-AI": ZhipuChat,
     "Tongyi-Qianwen": QWenChat,
     "Local": LocalLLM,
+    "Ollama": OllamaChat,
     "Moonshot": MoonshotChat
 }
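With these entries registered, the rest of the code can resolve the provider name string "Ollama" to a concrete wrapper class. A minimal sketch of that lookup, assuming the factory dicts are importable as `rag.llm.ChatModel` and that an Ollama server is reachable at the default `http://localhost:11434` (the model name and host are assumptions, not part of this diff):

```python
# Hypothetical factory lookup; the "Ollama" key comes from this diff,
# while the model name and base_url below are assumptions for illustration.
from rag.llm import ChatModel

chat_mdl = ChatModel["Ollama"](
    key="",                              # Ollama does not require an API key
    model_name="llama2",
    base_url="http://localhost:11434",
)
answer, used_tokens = chat_mdl.chat(
    system="You are a helpful assistant.",
    history=[{"role": "user", "content": "What is RAG?"}],
    gen_conf={"temperature": 0.1, "max_tokens": 256},
)
print(answer, used_tokens)
```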
@@ -18,6 +18,7 @@ from dashscope import Generation
 from abc import ABC
 from openai import OpenAI
 import openai
+from ollama import Client
 from rag.nlp import is_english
 from rag.utils import num_tokens_from_string

@@ -129,6 +130,32 @@ class ZhipuChat(Base):
             return "**ERROR**: " + str(e), 0


+class OllamaChat(Base):
+    def __init__(self, key, model_name, **kwargs):
+        self.client = Client(host=kwargs["base_url"])
+        self.model_name = model_name
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        try:
+            options = {"temperature": gen_conf.get("temperature", 0.1),
+                       "num_predict": gen_conf.get("max_tokens", 128),
+                       "top_k": gen_conf.get("top_p", 0.3),
+                       "presence_penalty": gen_conf.get("presence_penalty", 0.4),
+                       "frequency_penalty": gen_conf.get("frequency_penalty", 0.7),
+                       }
+            response = self.client.chat(
+                model=self.model_name,
+                messages=history,
+                options=options
+            )
+            ans = response["message"]["content"].strip()
+            return ans, response["eval_count"]
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+
 class LocalLLM(Base):
     class RPCProxy:
         def __init__(self, host, port):
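Note that the wrapper maps `gen_conf`'s `max_tokens` to Ollama's `num_predict` option and, somewhat unusually, feeds the `top_p` value into Ollama's `top_k`. For reference, a minimal standalone sketch of the underlying `ollama` client call that `OllamaChat.chat` delegates to, assuming a local server and an already-pulled model (host and model name are assumptions):

```python
# Direct use of the ollama Python client that OllamaChat wraps.
# Host and model name are assumptions for illustration only.
from ollama import Client

client = Client(host="http://localhost:11434")
response = client.chat(
    model="llama2",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize what Ollama does in one sentence."},
    ],
    options={"temperature": 0.1, "num_predict": 128},
)
print(response["message"]["content"])   # generated answer
print(response["eval_count"])           # completion token count used by OllamaChat.chat
```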
@@ -16,7 +16,7 @@
 from zhipuai import ZhipuAI
 import io
 from abc import ABC

+from ollama import Client
 from PIL import Image
 from openai import OpenAI
 import os
@@ -140,6 +140,28 @@ class Zhipu4V(Base):
         return res.choices[0].message.content.strip(), res.usage.total_tokens


+class OllamaCV(Base):
+    def __init__(self, key, model_name, lang="Chinese", **kwargs):
+        self.client = Client(host=kwargs["base_url"])
+        self.model_name = model_name
+        self.lang = lang
+
+    def describe(self, image, max_tokens=1024):
+        prompt = self.prompt("")
+        try:
+            options = {"num_predict": max_tokens}
+            response = self.client.generate(
+                model=self.model_name,
+                prompt=prompt[0]["content"][1]["text"],
+                images=[image],
+                options=options
+            )
+            ans = response["response"].strip()
+            return ans, 128
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+
 class LocalCV(Base):
     def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
         pass
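`OllamaCV` reuses the base class's `prompt("")` helper for the instruction text, sends the image to Ollama's `generate` endpoint, and returns a fixed 128 as a rough token count. A minimal usage sketch, assuming the class is importable as `rag.llm.cv_model.OllamaCV`, that a multimodal model such as `llava` has been pulled locally, and that raw image bytes are passed in (all of these are assumptions, not part of this diff):

```python
# Hypothetical call to the vision wrapper; import path, host, model name
# and file name are assumptions for illustration.
from rag.llm.cv_model import OllamaCV

cv_mdl = OllamaCV(key="", model_name="llava",
                  lang="English", base_url="http://localhost:11434")

with open("page_screenshot.png", "rb") as f:
    description, token_estimate = cv_mdl.describe(f.read(), max_tokens=512)

print(description)
```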
@@ -16,13 +16,12 @@
 from zhipuai import ZhipuAI
 import os
 from abc import ABC

+from ollama import Client
 import dashscope
 from openai import OpenAI
 from FlagEmbedding import FlagModel
 import torch
 import numpy as np
 from huggingface_hub import snapshot_download

 from api.utils.file_utils import get_project_base_directory
 from rag.utils import num_tokens_from_string
@@ -150,3 +149,24 @@ class ZhipuEmbed(Base):
         res = self.client.embeddings.create(input=text,
                                             model=self.model_name)
         return np.array(res.data[0].embedding), res.usage.total_tokens


+class OllamaEmbed(Base):
+    def __init__(self, key, model_name, **kwargs):
+        self.client = Client(host=kwargs["base_url"])
+        self.model_name = model_name
+
+    def encode(self, texts: list, batch_size=32):
+        arr = []
+        tks_num = 0
+        for txt in texts:
+            res = self.client.embeddings(prompt=txt,
+                                         model=self.model_name)
+            arr.append(res["embedding"])
+            tks_num += 128
+        return np.array(arr), tks_num
+
+    def encode_queries(self, text):
+        res = self.client.embeddings(prompt=text,
+                                     model=self.model_name)
+        return np.array(res["embedding"]), 128
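`OllamaEmbed` calls the embeddings endpoint one text at a time and reports a flat 128 tokens per text rather than a real count. A minimal usage sketch, assuming the class is importable as `rag.llm.embedding_model.OllamaEmbed` and that an embedding-capable model is pulled locally (import path, host and model name are assumptions):

```python
# Hypothetical embedding calls; import path, host and model name are assumptions.
from rag.llm.embedding_model import OllamaEmbed

embd_mdl = OllamaEmbed(key="", model_name="nomic-embed-text",
                       base_url="http://localhost:11434")

vectors, token_estimate = embd_mdl.encode(["first chunk of text", "second chunk of text"])
print(vectors.shape)      # (2, embedding_dim)

query_vec, _ = embd_mdl.encode_queries("what does this document say about Ollama?")
print(query_vec.shape)    # (embedding_dim,)
```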
@@ -23,7 +23,8 @@ import re
 import sys
 import traceback
 from functools import partial

+import signal
+from contextlib import contextmanager
 from rag.settings import database_logger
 from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
@@ -97,8 +98,21 @@ def collect(comm, mod, tm):
     cron_logger.info("TOTAL:{}, To:{}".format(len(tasks), mtm))
     return tasks


+@contextmanager
+def timeout(time):
+    # Register a function to raise a TimeoutError on the signal.
+    signal.signal(signal.SIGALRM, raise_timeout)
+    # Schedule the signal to be sent after ``time``.
+    signal.alarm(time)
+    yield
+
+
+def raise_timeout(signum, frame):
+    raise TimeoutError
+
+
 def build(row):
     from timeit import default_timer as timer
     if row["size"] > DOC_MAXIMUM_SIZE:
         set_progress(row["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
                      (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
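The `timeout` context manager arms a SIGALRM that fires `raise_timeout` after the given number of seconds; the next hunk uses it to bound the MinIO fetch to 30 seconds. A standalone sketch of the same pattern, with an added `signal.alarm(0)` reset in a `finally` block (the reset is an assumption, not part of this diff):

```python
# Standalone sketch of the SIGALRM-based timeout pattern used above.
# The finally/alarm(0) reset is an assumption, not part of this commit.
import signal
import time
from contextlib import contextmanager


def raise_timeout(signum, frame):
    raise TimeoutError


@contextmanager
def timeout(seconds):
    signal.signal(signal.SIGALRM, raise_timeout)   # install the handler
    signal.alarm(seconds)                          # schedule the signal
    try:
        yield
    finally:
        signal.alarm(0)                            # cancel any pending alarm


try:
    with timeout(1):
        time.sleep(5)                              # interrupted after ~1 second
except TimeoutError:
    print("operation timed out")
```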
@@ -111,11 +125,14 @@ def build(row):
                        row["to_page"])
     chunker = FACTORY[row["parser_id"].lower()]
     try:
-        cron_logger.info(
-            "Chunkking {}/{}".format(row["location"], row["name"]))
-        cks = chunker.chunk(row["name"], binary=MINIO.get(row["kb_id"], row["location"]), from_page=row["from_page"],
+        st = timer()
+        with timeout(30):
+            binary = MINIO.get(row["kb_id"], row["location"])
+        cks = chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
                             to_page=row["to_page"], lang=row["language"], callback=callback,
                             kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
+        cron_logger.info(
+            "Chunkking({}) {}/{}".format(timer()-st, row["location"], row["name"]))
     except Exception as e:
         if re.search("(No such file|not found)", str(e)):
             callback(-1, "Can not find file <%s>" % row["name"])