# Fix: default model base url extraction logic (#11263)
### What problem does this PR solve?
Fixes an issue where default models that use the same factory but different base URLs were all initialised with the default chat model's base URL, ignoring e.g. the embedding model's own `base_url` config.

For example, with the following service config, the embedding and reranker models would end up using the default chat model's base URL (i.e. `llm1.example.com`):
```yaml
ragflow:
  service_conf:
    user_default_llm:
      factory: OpenAI-API-Compatible
      api_key: not-used
      default_models:
        chat_model:
          name: llm1
          base_url: https://llm1.example.com/v1
        embedding_model:
          name: llm2
          base_url: https://llm2.example.com/v1
        rerank_model:
          name: llm3
          base_url: https://llm3.example.com/v1/rerank
    llm_factories:
      factory_llm_infos:
        - name: OpenAI-API-Compatible
          logo: ""
          tags: "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION"
          status: "1"
          llm:
            - llm_name: llm1
              base_url: 'https://llm1.example.com/v1'
              api_key: not-used
              tags: "LLM,CHAT,IMAGE2TEXT"
              max_tokens: 100000
              model_type: chat
              is_tools: false
            - llm_name: llm2
              base_url: https://llm2.example.com/v1
              api_key: not-used
              tags: "TEXT EMBEDDING"
              max_tokens: 10000
              model_type: embedding
            - llm_name: llm3
              base_url: https://llm3.example.com/v1/rerank
              api_key: not-used
              tags: "RERANK,1k"
              max_tokens: 10000
              model_type: rerank
```
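To illustrate the intended behaviour, here is a minimal, self-contained sketch of the per-model-type fallback. The plain dicts stand in for the parsed `default_models` entries and the factory-level defaults, and `resolve_base_url` is a hypothetical helper for illustration, not RAGFlow code:

```python
# Minimal sketch of the per-model-type fallback; the dicts below are
# illustrative stand-ins for the parsed service config, not RAGFlow's API.
CHAT_CFG = {"base_url": "https://llm1.example.com/v1", "api_key": "not-used"}
EMBEDDING_CFG = {"base_url": "https://llm2.example.com/v1", "api_key": "not-used"}
RERANK_CFG = {"base_url": "https://llm3.example.com/v1/rerank", "api_key": "not-used"}

model_configs = {
    "chat": CHAT_CFG,
    "embedding": EMBEDDING_CFG,
    "rerank": RERANK_CFG,
}

# Factory-level defaults; before the fix these were applied to every
# default model, so everything got the chat model's base URL.
factory_config = {"api_key": "not-used", "base_url": "https://llm1.example.com/v1"}

def resolve_base_url(model_type: str) -> str:
    # A model type with its own config wins; anything else falls back
    # to the factory-level default.
    return model_configs.get(model_type, {}).get("base_url", factory_config["base_url"])

assert resolve_base_url("embedding") == "https://llm2.example.com/v1"    # llm2's own URL
assert resolve_base_url("speech2text") == "https://llm1.example.com/v1"  # factory fallback
```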
### Type of change
- [X] Bug Fix (non-breaking change which fixes an issue)
```diff
@@ -19,6 +19,7 @@ import re
 from common.token_utils import num_tokens_from_string
 from functools import partial
 from typing import Generator
+from common.constants import LLMType
 from api.db.db_models import LLM
 from api.db.services.common_service import CommonService
 from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
@@ -32,6 +33,14 @@ def get_init_tenant_llm(user_id):
     from common import settings
     tenant_llm = []
 
+    model_configs = {
+        LLMType.CHAT: settings.CHAT_CFG,
+        LLMType.EMBEDDING: settings.EMBEDDING_CFG,
+        LLMType.SPEECH2TEXT: settings.ASR_CFG,
+        LLMType.IMAGE2TEXT: settings.IMAGE2TEXT_CFG,
+        LLMType.RERANK: settings.RERANK_CFG,
+    }
+
     seen = set()
     factory_configs = []
     for factory_config in [
@@ -54,8 +63,8 @@ def get_init_tenant_llm(user_id):
                 "llm_factory": factory_config["factory"],
                 "llm_name": llm.llm_name,
                 "model_type": llm.model_type,
-                "api_key": factory_config["api_key"],
-                "api_base": factory_config["base_url"],
+                "api_key": model_configs.get(llm.model_type, {}).get("api_key", factory_config["api_key"]),
+                "api_base": model_configs.get(llm.model_type, {}).get("base_url", factory_config["base_url"]),
                 "max_tokens": llm.max_tokens if llm.max_tokens else 8192,
             }
         )
@@ -80,8 +89,8 @@ class LLMBundle(LLM4Tenant):
 
     def encode(self, texts: list):
         if self.langfuse:
             generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})
 
         safe_texts = []
         for text in texts:
             token_size = num_tokens_from_string(text)
@@ -90,7 +99,7 @@ class LLMBundle(LLM4Tenant):
                 safe_texts.append(text[:target_len])
             else:
                 safe_texts.append(text)
 
         embeddings, used_tokens = self.mdl.encode(safe_texts)
 
         llm_name = getattr(self, "llm_name", None)
```
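The substantive change is the pair of replaced lines in `get_init_tenant_llm`: `api_key` and `api_base` are now resolved per model type through the new `model_configs` mapping, and the factory-level values are retained only as fallbacks, so configs that define no per-model `base_url` keep their previous behaviour.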