Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Refa: replace Chat Ollama implementation with LiteLLM (#9693)
### What problem does this PR solve?

Replace the Chat Ollama implementation with LiteLLM.

### Type of change

- [x] Refactoring
@@ -36,6 +36,7 @@ class SupportedLiteLLMProvider(StrEnum):
     Nvidia = "NVIDIA"
     TogetherAI = "TogetherAI"
     Anthropic = "Anthropic"
+    Ollama = "Ollama"


 FACTORY_DEFAULT_BASE_URL = {
@@ -59,6 +60,7 @@ LITELLM_PROVIDER_PREFIX = {
     SupportedLiteLLMProvider.Nvidia: "nvidia_nim/",
     SupportedLiteLLMProvider.TogetherAI: "together_ai/",
     SupportedLiteLLMProvider.Anthropic: "", # don't need a prefix
+    SupportedLiteLLMProvider.Ollama: "ollama_chat/",
 }

 ChatModel = globals().get("ChatModel", {})
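Note: with "Ollama" registered in SupportedLiteLLMProvider and LITELLM_PROVIDER_PREFIX, the "ollama_chat/" prefix is prepended to the configured model name so LiteLLM routes the request to an Ollama server. A minimal sketch of that routing (not code from this PR; the model name and endpoint are placeholders, assuming a local Ollama server on its default port):

```python
import litellm

# Mirrors the LITELLM_PROVIDER_PREFIX entry above: "ollama_chat/" tells LiteLLM
# to call an Ollama server through its native chat endpoint.
prefix = "ollama_chat/"      # SupportedLiteLLMProvider.Ollama
model_name = "llama3.1"      # placeholder: any model pulled into Ollama

response = litellm.completion(
    model=prefix + model_name,                 # -> "ollama_chat/llama3.1"
    messages=[{"role": "user", "content": "Hello"}],
    api_base="http://localhost:11434",         # default Ollama endpoint (assumption)
)
print(response.choices[0].message.content)
```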
@@ -29,7 +29,6 @@ import json_repair
 import litellm
 import openai
 import requests
-from ollama import Client
 from openai import OpenAI
 from openai.lib.azure import AzureOpenAI
 from strenum import StrEnum
@@ -683,73 +682,6 @@ class ZhipuChat(Base):
         return super().chat_streamly_with_tools(system, history, gen_conf)


-class OllamaChat(Base):
-    _FACTORY_NAME = "Ollama"
-
-    def __init__(self, key, model_name, base_url=None, **kwargs):
-        super().__init__(key, model_name, base_url=base_url, **kwargs)
-
-        self.client = Client(host=base_url) if not key or key == "x" else Client(host=base_url, headers={"Authorization": f"Bearer {key}"})
-        self.model_name = model_name
-        self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1)))
-
-    def _clean_conf(self, gen_conf):
-        options = {}
-        if "max_tokens" in gen_conf:
-            options["num_predict"] = gen_conf["max_tokens"]
-        for k in ["temperature", "top_p", "presence_penalty", "frequency_penalty"]:
-            if k not in gen_conf:
-                continue
-            options[k] = gen_conf[k]
-        return options
-
-    def _chat(self, history, gen_conf={}, **kwargs):
-        # Calculate context size
-        ctx_size = self._calculate_dynamic_ctx(history)
-
-        gen_conf["num_ctx"] = ctx_size
-        response = self.client.chat(model=self.model_name, messages=history, options=gen_conf, keep_alive=self.keep_alive)
-        ans = response["message"]["content"].strip()
-        token_count = response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
-        return ans, token_count
-
-    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-        try:
-            # Calculate context size
-            ctx_size = self._calculate_dynamic_ctx(history)
-            options = {"num_ctx": ctx_size}
-            if "temperature" in gen_conf:
-                options["temperature"] = gen_conf["temperature"]
-            if "max_tokens" in gen_conf:
-                options["num_predict"] = gen_conf["max_tokens"]
-            if "top_p" in gen_conf:
-                options["top_p"] = gen_conf["top_p"]
-            if "presence_penalty" in gen_conf:
-                options["presence_penalty"] = gen_conf["presence_penalty"]
-            if "frequency_penalty" in gen_conf:
-                options["frequency_penalty"] = gen_conf["frequency_penalty"]
-
-            ans = ""
-            try:
-                response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=self.keep_alive)
-                for resp in response:
-                    if resp["done"]:
-                        token_count = resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
-                        yield token_count
-                    ans = resp["message"]["content"]
-                    yield ans
-            except Exception as e:
-                yield ans + "\n**ERROR**: " + str(e)
-            yield 0
-        except Exception as e:
-            yield "**ERROR**: " + str(e)
-            yield 0
-
-
 class LocalAIChat(Base):
     _FACTORY_NAME = "LocalAI"

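Note: the deleted OllamaChat class called the Ollama server directly through the ollama Python client and hand-translated gen_conf into Ollama options (num_predict, num_ctx, and so on). After this refactor the same requests are expected to flow through the shared LiteLLM code path instead. The helper below is a hypothetical sketch of the equivalent non-streaming call, not RAGFlow's actual implementation; the function name, environment variable, and defaults are assumptions for illustration:

```python
import os
import litellm


def ollama_chat_via_litellm(model_name, history, gen_conf=None):
    """Hypothetical helper: roughly what the old OllamaChat._chat did, expressed via LiteLLM."""
    gen_conf = gen_conf or {}
    response = litellm.completion(
        model="ollama_chat/" + model_name,
        messages=history,
        api_base=os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434"),
        temperature=gen_conf.get("temperature"),
        top_p=gen_conf.get("top_p"),
        max_tokens=gen_conf.get("max_tokens"),  # roughly equivalent to Ollama's num_predict
    )
    ans = response.choices[0].message.content.strip()
    token_count = response.usage.total_tokens   # prompt + completion tokens
    return ans, token_count
```

The old dynamic num_ctx sizing from _calculate_dynamic_ctx is deliberately not reproduced here; whether an explicit context override is still needed depends on the Ollama model's defaults.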
@@ -1422,7 +1354,7 @@ class Ai302Chat(Base):


 class LiteLLMBase(ABC):
-    _FACTORY_NAME = ["Tongyi-Qianwen", "Bedrock", "Moonshot", "xAI", "DeepInfra", "Groq", "Cohere", "Gemini", "DeepSeek", "NVIDIA", "TogetherAI", "Anthropic"]
+    _FACTORY_NAME = ["Tongyi-Qianwen", "Bedrock", "Moonshot", "xAI", "DeepInfra", "Groq", "Cohere", "Gemini", "DeepSeek", "NVIDIA", "TogetherAI", "Anthropic", "Ollama"]

     def __init__(self, key, model_name, base_url=None, **kwargs):
         self.timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
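Note: listing "Ollama" in LiteLLMBase._FACTORY_NAME is what lets the model factory resolve Ollama chat models to the LiteLLM-backed class, streaming included, so the removed chat_streamly generator is covered by LiteLLM's stream mode. A hedged sketch of that streaming usage, with a placeholder model name and endpoint:

```python
import litellm

# Streamed completion against a local Ollama server; chunks follow the
# OpenAI-style delta format that LiteLLM normalizes to.
stream = litellm.completion(
    model="ollama_chat/llama3.1",              # placeholder model
    messages=[{"role": "user", "content": "Stream a short answer."}],
    api_base="http://localhost:11434",         # assumed default Ollama endpoint
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)
```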