Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Support Ollama (#261)
### What problem does this PR solve?

Issue link: #221

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
`rag/llm/__init__.py`

```diff
@@ -19,7 +19,7 @@ from .cv_model import *
 
 
 EmbeddingModel = {
     "Local": HuEmbedding,
+    "Ollama": OllamaEmbed,
     "OpenAI": OpenAIEmbed,
     "Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
     "ZHIPU-AI": ZhipuEmbed,
@@ -29,7 +29,7 @@ EmbeddingModel = {
 
 CvModel = {
     "OpenAI": GptV4,
     "Local": LocalCV,
+    "Ollama": OllamaCV,
     "Tongyi-Qianwen": QWenCV,
     "ZHIPU-AI": Zhipu4V,
     "Moonshot": LocalCV
@@ -40,7 +40,7 @@ ChatModel = {
     "OpenAI": GptTurbo,
     "ZHIPU-AI": ZhipuChat,
     "Tongyi-Qianwen": QWenChat,
     "Local": LocalLLM,
+    "Ollama": OllamaChat,
     "Moonshot": MoonshotChat
 }
 
```
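These three dicts are the provider registries that the rest of RAGFlow dispatches on: a tenant's configured factory name selects an adapter class, and every adapter shares the `(key, model_name, **kwargs)` constructor shape, which is what lets `Ollama` slot in beside the hosted providers. A minimal sketch of that lookup, where `build_chat_model` and its parameters are illustrative and not part of this commit:

```python
# Sketch of how a caller might resolve a chat adapter from the registry.
# build_chat_model and its arguments are hypothetical, not from this diff.
def build_chat_model(factory: str, llm_name: str, api_key: str, **kwargs):
    if factory not in ChatModel:
        raise LookupError(f"Unknown LLM factory: {factory}")
    # All adapters accept (key, model_name, **kwargs); the Ollama adapter
    # ignores the key but requires base_url to reach the local server.
    return ChatModel[factory](api_key, llm_name, **kwargs)

# e.g. mdl = build_chat_model("Ollama", "llama2", "", base_url="http://localhost:11434")
```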
`rag/llm/chat_model.py`

```diff
@@ -18,6 +18,7 @@ from dashscope import Generation
 from abc import ABC
 from openai import OpenAI
 import openai
+from ollama import Client
 from rag.nlp import is_english
 from rag.utils import num_tokens_from_string
 
@@ -129,6 +130,32 @@ class ZhipuChat(Base):
             return "**ERROR**: " + str(e), 0
 
 
+class OllamaChat(Base):
+    def __init__(self, key, model_name, **kwargs):
+        self.client = Client(host=kwargs["base_url"])
+        self.model_name = model_name
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        try:
+            options = {"temperature": gen_conf.get("temperature", 0.1),
+                       "num_predict": gen_conf.get("max_tokens", 128),
+                       "top_k": gen_conf.get("top_p", 0.3),
+                       "presence_penalty": gen_conf.get("presence_penalty", 0.4),
+                       "frequency_penalty": gen_conf.get("frequency_penalty", 0.7),
+                       }
+            response = self.client.chat(
+                model=self.model_name,
+                messages=history,
+                options=options
+            )
+            ans = response["message"]["content"].strip()
+            return ans, response["eval_count"]
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+
 class LocalLLM(Base):
     class RPCProxy:
         def __init__(self, host, port):
```
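`OllamaChat` talks to a local Ollama server through the official `ollama` Python client, so the `key` argument exists only to keep the shared constructor signature; what matters is `base_url`. Two details worth noting from the code: `gen_conf`'s `top_p` value is forwarded as Ollama's `top_k` option, and the returned token count comes from the response's `eval_count` field. A usage sketch, assuming a local server with a pulled `llama2` model (both assumptions, not part of the diff):

```python
# Hypothetical call into the OllamaChat adapter above. Assumes
# `ollama serve` is running locally and `ollama pull llama2` was done.
chat_mdl = OllamaChat("ignored-key", "llama2", base_url="http://localhost:11434")
answer, tokens = chat_mdl.chat(
    system="You are a concise assistant.",
    history=[{"role": "user", "content": "Summarize RAG in one sentence."}],
    gen_conf={"temperature": 0.2, "max_tokens": 256},
)
print(answer, tokens)  # tokens is Ollama's eval_count, or 0 on error
```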
`rag/llm/cv_model.py`

```diff
@@ -16,7 +16,7 @@
 from zhipuai import ZhipuAI
 import io
 from abc import ABC
 
+from ollama import Client
 from PIL import Image
 from openai import OpenAI
 import os
@@ -140,6 +140,28 @@ class Zhipu4V(Base):
         return res.choices[0].message.content.strip(), res.usage.total_tokens
 
 
+class OllamaCV(Base):
+    def __init__(self, key, model_name, lang="Chinese", **kwargs):
+        self.client = Client(host=kwargs["base_url"])
+        self.model_name = model_name
+        self.lang = lang
+
+    def describe(self, image, max_tokens=1024):
+        prompt = self.prompt("")
+        try:
+            options = {"num_predict": max_tokens}
+            response = self.client.generate(
+                model=self.model_name,
+                prompt=prompt[0]["content"][1]["text"],
+                images=[image],
+                options=options
+            )
+            ans = response["response"].strip()
+            return ans, 128
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+
 class LocalCV(Base):
     def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
         pass
```
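`OllamaCV` reuses the shared `prompt()` template from the vision `Base` class, extracting just its text portion and handing the image to Ollama's `generate` endpoint; since the endpoint reports no usage for image requests, `describe` returns a flat 128-token estimate. A sketch of calling it, assuming a multimodal model such as `llava` and a local image file (both assumptions):

```python
import base64

# Hypothetical call into the OllamaCV adapter above. RAGFlow's vision
# adapters receive images as base64 strings, which the ollama client accepts.
cv_mdl = OllamaCV("ignored-key", "llava", lang="English",
                  base_url="http://localhost:11434")
with open("sample.jpg", "rb") as f:
    img_b64 = base64.b64encode(f.read())
description, tks = cv_mdl.describe(img_b64, max_tokens=512)
print(description)
```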
`rag/llm/embedding_model.py`

```diff
@@ -16,13 +16,12 @@
 from zhipuai import ZhipuAI
 import os
 from abc import ABC
 
+from ollama import Client
 import dashscope
 from openai import OpenAI
 from FlagEmbedding import FlagModel
 import torch
 import numpy as np
 from huggingface_hub import snapshot_download
 
 from api.utils.file_utils import get_project_base_directory
 from rag.utils import num_tokens_from_string
@@ -150,3 +149,24 @@ class ZhipuEmbed(Base):
         res = self.client.embeddings.create(input=text,
                                             model=self.model_name)
         return np.array(res.data[0].embedding), res.usage.total_tokens
+
+
+class OllamaEmbed(Base):
+    def __init__(self, key, model_name, **kwargs):
+        self.client = Client(host=kwargs["base_url"])
+        self.model_name = model_name
+
+    def encode(self, texts: list, batch_size=32):
+        arr = []
+        tks_num = 0
+        for txt in texts:
+            res = self.client.embeddings(prompt=txt,
+                                         model=self.model_name)
+            arr.append(res["embedding"])
+            tks_num += 128
+        return np.array(arr), tks_num
+
+    def encode_queries(self, text):
+        res = self.client.embeddings(prompt=text,
+                                     model=self.model_name)
+        return np.array(res["embedding"]), 128
```
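Ollama's embeddings endpoint accepts one prompt per request, so `encode` loops over the batch and ignores `batch_size`, and token usage is approximated as a flat 128 per text rather than counted. A usage sketch, assuming an embedding-capable model such as `nomic-embed-text` has been pulled (an assumption; any embedding model served by Ollama works the same way):

```python
# Hypothetical usage of the OllamaEmbed adapter above.
embd_mdl = OllamaEmbed("ignored-key", "nomic-embed-text",
                       base_url="http://localhost:11434")
doc_vecs, tks = embd_mdl.encode(["hello world", "ragflow supports ollama"])
print(doc_vecs.shape, tks)   # e.g. (2, 768) and 256 with the flat estimate
query_vec, _ = embd_mdl.encode_queries("what is ragflow?")
```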