Refa: replace Chat Ollama implementation with LiteLLM (#9693)

### What problem does this PR solve?

Replace the Chat Ollama implementation with LiteLLM.

### Type of change

- [x] Refactoring
Yongteng Lei committed 2025-08-25 17:56:31 +08:00 (committed via GitHub)
parent d367c7e226
commit b6c1ca828e
2 changed files with 3 additions and 69 deletions


@@ -36,6 +36,7 @@ class SupportedLiteLLMProvider(StrEnum):
     Nvidia = "NVIDIA"
     TogetherAI = "TogetherAI"
     Anthropic = "Anthropic"
+    Ollama = "Ollama"
 
 FACTORY_DEFAULT_BASE_URL = {
@@ -59,6 +60,7 @@ LITELLM_PROVIDER_PREFIX = {
     SupportedLiteLLMProvider.Nvidia: "nvidia_nim/",
     SupportedLiteLLMProvider.TogetherAI: "together_ai/",
     SupportedLiteLLMProvider.Anthropic: "",  # don't need a prefix
+    SupportedLiteLLMProvider.Ollama: "ollama_chat/",
 }
 
 ChatModel = globals().get("ChatModel", {})
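With the new mapping entry, Ollama chat requests are routed through LiteLLM's `ollama_chat/` provider instead of the `ollama` Python client. A minimal sketch of what that routing amounts to (the helper name, base URL, and model name are illustrative and not part of this PR; `litellm.completion` and `api_base` are standard LiteLLM API):

```python
import litellm

LITELLM_PROVIDER_PREFIX = {"Ollama": "ollama_chat/"}  # mirrors the entry added above


def chat_via_litellm(model_name: str, messages: list[dict], base_url: str = "http://localhost:11434") -> str:
    # "ollama_chat/<model>" tells LiteLLM to call Ollama's chat endpoint at api_base.
    response = litellm.completion(
        model=LITELLM_PROVIDER_PREFIX["Ollama"] + model_name,
        messages=messages,
        api_base=base_url,
    )
    return response.choices[0].message.content


# Example: chat_via_litellm("llama3", [{"role": "user", "content": "Hello"}])
```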


@@ -29,7 +29,6 @@ import json_repair
 import litellm
 import openai
 import requests
-from ollama import Client
 from openai import OpenAI
 from openai.lib.azure import AzureOpenAI
 from strenum import StrEnum
@@ -683,73 +682,6 @@ class ZhipuChat(Base):
         return super().chat_streamly_with_tools(system, history, gen_conf)
 
 
-class OllamaChat(Base):
-    _FACTORY_NAME = "Ollama"
-
-    def __init__(self, key, model_name, base_url=None, **kwargs):
-        super().__init__(key, model_name, base_url=base_url, **kwargs)
-        self.client = Client(host=base_url) if not key or key == "x" else Client(host=base_url, headers={"Authorization": f"Bearer {key}"})
-        self.model_name = model_name
-        self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1)))
-
-    def _clean_conf(self, gen_conf):
-        options = {}
-        if "max_tokens" in gen_conf:
-            options["num_predict"] = gen_conf["max_tokens"]
-        for k in ["temperature", "top_p", "presence_penalty", "frequency_penalty"]:
-            if k not in gen_conf:
-                continue
-            options[k] = gen_conf[k]
-        return options
-
-    def _chat(self, history, gen_conf={}, **kwargs):
-        # Calculate context size
-        ctx_size = self._calculate_dynamic_ctx(history)
-
-        gen_conf["num_ctx"] = ctx_size
-        response = self.client.chat(model=self.model_name, messages=history, options=gen_conf, keep_alive=self.keep_alive)
-        ans = response["message"]["content"].strip()
-        token_count = response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
-        return ans, token_count
-
-    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-        try:
-            # Calculate context size
-            ctx_size = self._calculate_dynamic_ctx(history)
-            options = {"num_ctx": ctx_size}
-            if "temperature" in gen_conf:
-                options["temperature"] = gen_conf["temperature"]
-            if "max_tokens" in gen_conf:
-                options["num_predict"] = gen_conf["max_tokens"]
-            if "top_p" in gen_conf:
-                options["top_p"] = gen_conf["top_p"]
-            if "presence_penalty" in gen_conf:
-                options["presence_penalty"] = gen_conf["presence_penalty"]
-            if "frequency_penalty" in gen_conf:
-                options["frequency_penalty"] = gen_conf["frequency_penalty"]
-
-            ans = ""
-            try:
-                response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=self.keep_alive)
-                for resp in response:
-                    if resp["done"]:
-                        token_count = resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
-                        yield token_count
-                    ans = resp["message"]["content"]
-                    yield ans
-            except Exception as e:
-                yield ans + "\n**ERROR**: " + str(e)
-            yield 0
-        except Exception as e:
-            yield "**ERROR**: " + str(e)
-            yield 0
-
-
 class LocalAIChat(Base):
     _FACTORY_NAME = "LocalAI"
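The deleted `OllamaChat` built Ollama-specific options (`num_ctx`, `num_predict`, `keep_alive`) by hand and iterated the raw Ollama stream. Under LiteLLM, the equivalent streaming flow would presumably pass OpenAI-style parameters and read OpenAI-style chunks, roughly as in this sketch (the function name, defaults, and which `gen_conf` keys get forwarded are assumptions, not code from this PR):

```python
import litellm


def stream_ollama_chat(model_name, history, gen_conf=None, base_url="http://localhost:11434"):
    # LiteLLM is expected to translate OpenAI-style params (max_tokens, temperature,
    # top_p) into Ollama options, replacing the hand-built `options` dict above.
    gen_conf = gen_conf or {}
    forwarded = {k: v for k, v in gen_conf.items() if k in ("max_tokens", "temperature", "top_p")}
    answer = ""
    stream = litellm.completion(
        model=f"ollama_chat/{model_name}",
        messages=history,
        api_base=base_url,
        stream=True,
        **forwarded,
    )
    for chunk in stream:
        # Streaming chunks expose .choices[0].delta.content, as with the OpenAI SDK.
        delta = chunk.choices[0].delta.content or ""
        answer += delta
        yield answer
```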
@@ -1422,7 +1354,7 @@ class Ai302Chat(Base):
 
 class LiteLLMBase(ABC):
-    _FACTORY_NAME = ["Tongyi-Qianwen", "Bedrock", "Moonshot", "xAI", "DeepInfra", "Groq", "Cohere", "Gemini", "DeepSeek", "NVIDIA", "TogetherAI", "Anthropic"]
+    _FACTORY_NAME = ["Tongyi-Qianwen", "Bedrock", "Moonshot", "xAI", "DeepInfra", "Groq", "Cohere", "Gemini", "DeepSeek", "NVIDIA", "TogetherAI", "Anthropic", "Ollama"]
 
     def __init__(self, key, model_name, base_url=None, **kwargs):
         self.timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))