Files
ragflow/api/db/services/llm_service.py
Scott Davidson 6b64641042 Fix: default model base url extraction logic (#11263)
### What problem does this PR solve?

Fixes an issue where default models that share the same factory but use
different base URLs were all initialised with the default chat model's
base URL, ignoring the model-specific setting (e.g. the embedding model's
own base URL config).

For example, with the following service config, the embedding and
reranker models would end up using the base URL for the default chat
model (i.e. `llm1.example.com`):

```yaml
ragflow:
  service_conf:
    user_default_llm:
      factory: OpenAI-API-Compatible
      api_key: not-used
      default_models:
        chat_model:
          name: llm1
          base_url: https://llm1.example.com/v1
        embedding_model:
          name: llm2
          base_url: https://llm2.example.com/v1
        rerank_model:
          name: llm3
          base_url: https://llm3.example.com/v1/rerank

  llm_factories:
    factory_llm_infos:
    - name: OpenAI-API-Compatible
      logo: ""
      tags: "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION"
      status: "1"
      llm:
        - llm_name: llm1
          base_url: 'https://llm1.example.com/v1'
          api_key: not-used
          tags: "LLM,CHAT,IMAGE2TEXT"
          max_tokens: 100000
          model_type: chat
          is_tools: false

        - llm_name: llm2
          base_url: https://llm2.example.com/v1
          api_key: not-used
          tags: "TEXT EMBEDDING"
          max_tokens: 10000
          model_type: embedding

        - llm_name: llm3
          base_url: https://llm3.example.com/v1/rerank
          api_key: not-used
          tags: "RERANK,1k"
          max_tokens: 10000
          model_type: rerank
```

### Type of change

- [X] Bug Fix (non-breaking change which fixes an issue)
2025-11-17 14:21:27 +08:00

284 lines
12 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
import logging
import re
from common.token_utils import num_tokens_from_string
from functools import partial
from typing import Generator
from common.constants import LLMType
from api.db.db_models import LLM
from api.db.services.common_service import CommonService
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
class LLMService(CommonService):
    """Service class bound to the ``LLM`` database model.

    The class body only sets ``model``; every query helper used elsewhere in
    this module (e.g. ``LLMService.query``) is inherited from
    ``CommonService``, which is defined outside this file.
    """

    # Database model this service operates on.
    model = LLM
def get_init_tenant_llm(user_id):
    """Build the initial tenant-LLM rows for ``user_id``.

    Every distinct factory named by the default model settings contributes
    one row per model returned by ``LLMService.query(fid=<factory>)``.  The
    ``api_key`` / ``api_base`` of a row come from the default configuration of
    the model's own type (chat, embedding, ...) when that configuration
    belongs to the same factory; otherwise the first configuration seen for
    the factory is used, so credentials of one factory are never attached to
    a model of another factory.

    Args:
        user_id: Value stored as ``tenant_id`` on every returned row.

    Returns:
        list[dict]: Rows with the keys ``tenant_id``, ``llm_factory``,
        ``llm_name``, ``model_type``, ``api_key``, ``api_base`` and
        ``max_tokens``.  Rows are unique per
        ``(tenant_id, llm_factory, llm_name)``; the first occurrence wins.
    """
    # Imported lazily, as in the original implementation.
    from common import settings

    # Default configuration for each model type.
    model_configs = {
        LLMType.CHAT: settings.CHAT_CFG,
        LLMType.EMBEDDING: settings.EMBEDDING_CFG,
        LLMType.SPEECH2TEXT: settings.ASR_CFG,
        LLMType.IMAGE2TEXT: settings.IMAGE2TEXT_CFG,
        LLMType.RERANK: settings.RERANK_CFG,
    }

    # First configuration seen for each factory, in declaration order.
    factory_configs = {}
    for config in model_configs.values():
        factory_configs.setdefault(config["factory"], config)

    unique = {}
    for factory_name, factory_config in factory_configs.items():
        for llm in LLMService.query(fid=factory_name):
            key = (user_id, factory_name, llm.llm_name)
            if key in unique:
                continue

            # Only trust the per-type defaults when they were configured for
            # this very factory; a config of another factory would leak the
            # wrong credentials / endpoint into this row.
            type_config = model_configs.get(llm.model_type)
            if type_config is None or type_config["factory"] != factory_name:
                type_config = factory_config

            unique[key] = {
                "tenant_id": user_id,
                "llm_factory": factory_name,
                "llm_name": llm.llm_name,
                "model_type": llm.model_type,
                "api_key": type_config.get("api_key", factory_config["api_key"]),
                "api_base": type_config.get("base_url", factory_config["base_url"]),
                # Fall back to 8192 when the model declares no token limit.
                "max_tokens": llm.max_tokens if llm.max_tokens else 8192,
            }
    return list(unique.values())
class LLMBundle(LLM4Tenant):
    """Tenant-scoped wrapper around a model instance (``self.mdl``).

    Each public method forwards to the corresponding method of ``self.mdl``,
    reports the consumed tokens through ``TenantLLMService.increase_usage``
    and, when ``self.langfuse`` is truthy, records the call as a Langfuse
    generation.

    The attributes read here (``mdl``, ``langfuse``, ``trace_context``,
    ``llm_name``, ``llm_type``, ``tenant_id``, ``max_length``, ``is_tools``,
    ``verbose_tool_use``) are not assigned in this class; presumably they are
    initialised by ``LLM4Tenant.__init__`` - confirm against that class.
    """

    def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
        """Forward all arguments unchanged to ``LLM4Tenant.__init__``."""
        super().__init__(tenant_id, llm_type, llm_name, lang, **kwargs)

    def bind_tools(self, toolcall_session, tools):
        """Bind ``tools`` to the underlying model if tool calls are supported.

        When ``self.is_tools`` is falsy a warning is logged and nothing is
        bound.
        """
        if not self.is_tools:
            logging.warning(f"Model {self.llm_name} does not support tool call, but you have assigned one or more tools to it!")
            return
        self.mdl.bind_tools(toolcall_session, tools)

    def encode(self, texts: list):
        """Embed ``texts`` with the underlying model.

        Texts whose token count exceeds ``self.max_length`` are cut to the
        first ``int(self.max_length * 0.95)`` characters before encoding.

        Returns:
            tuple: ``(embeddings, used_tokens)`` as returned by
            ``self.mdl.encode``.
        """
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})

        safe_texts = []
        for text in texts:
            token_size = num_tokens_from_string(text)
            if token_size > self.max_length:
                # NOTE: the limit is measured in tokens but the cut is made in
                # characters; this is the original (approximate) behaviour.
                target_len = int(self.max_length * 0.95)
                safe_texts.append(text[:target_len])
            else:
                safe_texts.append(text)

        embeddings, used_tokens = self.mdl.encode(safe_texts)

        llm_name = getattr(self, "llm_name", None)
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
            logging.error("LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(usage_details={"total_tokens": used_tokens})
            generation.end()

        return embeddings, used_tokens

    def encode_queries(self, query: str):
        """Embed a single query string.

        Returns:
            tuple: ``(embedding, used_tokens)`` as returned by
            ``self.mdl.encode_queries``.
        """
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode_queries", model=self.llm_name, input={"query": query})

        emd, used_tokens = self.mdl.encode_queries(query)

        llm_name = getattr(self, "llm_name", None)
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
            logging.error("LLMBundle.encode_queries can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(usage_details={"total_tokens": used_tokens})
            generation.end()

        return emd, used_tokens

    def similarity(self, query: str, texts: list):
        """Score ``texts`` against ``query`` with the underlying model.

        Returns:
            tuple: ``(sim, used_tokens)`` as returned by
            ``self.mdl.similarity``.
        """
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="similarity", model=self.llm_name, input={"query": query, "texts": texts})

        sim, used_tokens = self.mdl.similarity(query, texts)

        # Pass the model name like encode/tts/chat do, so usage is attributed
        # to the model that was actually called.
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
            logging.error("LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(usage_details={"total_tokens": used_tokens})
            generation.end()

        return sim, used_tokens

    def describe(self, image, max_tokens=300):
        """Return the model's textual description of ``image``.

        Args:
            image: Image payload forwarded to ``self.mdl.describe``.
            max_tokens: Accepted for interface compatibility; it is not
                forwarded to the model.

        Returns:
            The text returned by the model (the token count is only used for
            usage accounting).
        """
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe", metadata={"model": self.llm_name})

        txt, used_tokens = self.mdl.describe(image)

        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
            logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
            generation.end()

        return txt

    def describe_with_prompt(self, image, prompt):
        """Return the model's description of ``image`` guided by ``prompt``.

        Returns:
            The text returned by ``self.mdl.describe_with_prompt``.
        """
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe_with_prompt", metadata={"model": self.llm_name, "prompt": prompt})

        txt, used_tokens = self.mdl.describe_with_prompt(image, prompt)

        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
            # The message used to name ``describe``; it now names this method.
            logging.error("LLMBundle.describe_with_prompt can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
            generation.end()

        return txt

    def transcription(self, audio):
        """Transcribe ``audio`` and return the resulting text.

        Returns:
            The text returned by ``self.mdl.transcription``.
        """
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="transcription", metadata={"model": self.llm_name})

        txt, used_tokens = self.mdl.transcription(audio)

        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
            logging.error("LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
            generation.end()

        return txt

    def tts(self, text: str) -> Generator[bytes, None, None]:
        """Stream synthesized audio chunks for ``text``.

        Chunks produced by ``self.mdl.tts`` are yielded as-is.  An ``int``
        chunk is treated as the token count: it is reported as usage and
        ends the stream without being yielded.
        """
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="tts", input={"text": text})

        try:
            for chunk in self.mdl.tts(text):
                if isinstance(chunk, int):
                    if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, chunk, self.llm_name):
                        logging.error("LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
                    return
                yield chunk
        finally:
            # Close the generation on every exit path; previously the early
            # ``return`` on the token-count chunk skipped ``generation.end()``.
            if self.langfuse:
                generation.end()

    def _remove_reasoning_content(self, txt: str) -> str:
        """Drop everything up to and including the last ``</think>`` tag.

        ``txt`` is returned unchanged when it has no ``<think>`` tag, no
        ``</think>`` tag, or when the last ``</think>`` precedes the first
        ``<think>``.
        """
        open_tag, close_tag = "<think>", "</think>"
        first_think_start = txt.find(open_tag)
        if first_think_start == -1:
            return txt

        last_think_end = txt.rfind(close_tag)
        # Covers both "not found" (-1) and "closing tag before opening tag".
        if last_think_end < first_think_start:
            return txt

        return txt[last_think_end + len(close_tag) :]

    @staticmethod
    def _clean_param(chat_partial, **kwargs):
        """Keep only the keyword arguments accepted by ``chat_partial.func``.

        Args:
            chat_partial: A ``functools.partial`` whose ``func`` signature is
                inspected.
            **kwargs: Candidate keyword arguments.

        Returns:
            dict: ``kwargs`` unchanged when the function accepts ``**kwargs``;
            otherwise only the entries whose name matches a
            positional-or-keyword or keyword-only parameter.
        """
        sig = inspect.signature(chat_partial.func)
        keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)

        allowed_params = set()
        for param in sig.parameters.values():
            if param.kind == inspect.Parameter.VAR_KEYWORD:
                # The callee swallows arbitrary keywords: nothing to filter.
                return kwargs
            if param.kind in keyword_kinds:
                allowed_params.add(param.name)

        return {k: v for k, v in kwargs.items() if k in allowed_params}

    def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
        """Run a blocking chat completion and return the cleaned answer.

        ``self.mdl.chat_with_tools`` is used instead of ``self.mdl.chat`` when
        both the bundle and the model report tool support.  Reasoning content
        (``<think>...</think>``) is removed from the answer, and
        ``<tool_call>`` blocks are removed unless ``self.verbose_tool_use`` is
        truthy.

        Args:
            system: System prompt forwarded to the model.
            history: Conversation history forwarded to the model.
            gen_conf: Generation options forwarded unchanged to the model.
            **kwargs: Extra keyword arguments; only those accepted by the
                selected model method are forwarded.

        Returns:
            str: The cleaned answer text.
        """
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})

        # Do not bind ``**kwargs`` into the partial: that would bypass the
        # filtering done by ``_clean_param`` (chat_streamly does the same).
        chat_partial = partial(self.mdl.chat, system, history, gen_conf)
        if self.is_tools and self.mdl.is_tools:
            chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf)

        use_kwargs = self._clean_param(chat_partial, **kwargs)
        txt, used_tokens = chat_partial(**use_kwargs)
        txt = self._remove_reasoning_content(txt)

        if not self.verbose_tool_use:
            txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)

        # The guard used to be ``isinstance(txt, int)``, which is never true
        # for the answer string, so chat usage was never recorded.
        if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
            logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))

        if self.langfuse:
            generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
            generation.end()

        return txt

    def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
        """Stream a chat completion, yielding the accumulated answer so far.

        Each string chunk from the model is appended to the running answer,
        which is yielded after every chunk.  An ``int`` chunk is treated as
        the total token count: it ends the stream and is reported as usage.
        ``<tool_call>`` blocks inside a chunk are removed unless
        ``self.verbose_tool_use`` is truthy.

        Args:
            system: System prompt forwarded to the model.
            history: Conversation history forwarded to the model.
            gen_conf: Generation options forwarded unchanged to the model.
            **kwargs: Extra keyword arguments; only those accepted by the
                selected model method are forwarded.

        Yields:
            str: The answer accumulated up to the current chunk.
        """
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})

        chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf)
        if self.is_tools and self.mdl.is_tools:
            chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf)
        use_kwargs = self._clean_param(chat_partial, **kwargs)

        think_end = "</think>"
        ans = ""
        total_tokens = 0
        try:
            for txt in chat_partial(**use_kwargs):
                if isinstance(txt, int):
                    total_tokens = txt
                    break

                # A chunk ending with the closing tag continues the reasoning
                # section: drop the previous closing tag so only one remains.
                # Only strip when the tag is really there, otherwise genuine
                # answer characters would be cut off.
                if txt.endswith(think_end) and ans.endswith(think_end):
                    ans = ans[: -len(think_end)]

                if not self.verbose_tool_use:
                    txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)

                ans += txt
                yield ans
        finally:
            # End the generation on every exit path, including streams that
            # never report a token count or are abandoned by the consumer.
            if self.langfuse:
                generation.update(output={"output": ans})
                generation.end()

        if total_tokens > 0:
            if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
                logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, total_tokens))