Fix typos: retrievaler -> retriever (#10372)

### What problem does this PR solve?

Fix typos

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai
2025-10-10 09:17:36 +08:00
committed by GitHub
parent f4324e89d9
commit d931c33ced
37 changed files with 438 additions and 176 deletions

View File

@ -370,7 +370,7 @@ def chat(dialog, messages, stream=True, **kwargs):
chat_mdl.bind_tools(toolcall_session, tools)
bind_models_ts = timer()
retriever = settings.retrievaler
retriever = settings.retriever
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
if "doc_ids" in messages[-1]:
@ -472,7 +472,7 @@ def chat(dialog, messages, stream=True, **kwargs):
kbinfos["chunks"].extend(tav_res["chunks"])
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
if prompt_config.get("use_kg"):
ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
ck = settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
LLMBundle(dialog.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck)
@ -658,7 +658,7 @@ Please write the SQL, only SQL, without any other explanations or text.
logging.debug(f"{question} get SQL(refined): {sql}")
tried_times += 1
return settings.retrievaler.sql_retrieval(sql, format="json"), sql
return settings.retriever.sql_retrieval(sql, format="json"), sql
tbl, sql = get_table()
if tbl is None:
@ -752,7 +752,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
embedding_list = list(set([kb.embd_id for kb in kbs]))
is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
@ -848,7 +848,7 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
if not doc_ids:
doc_ids = None
ranks = settings.retrievaler.retrieval(
ranks = settings.retriever.retrieval(
question=question,
embd_mdl=embd_mdl,
tenant_ids=tenant_ids,

View File

@ -33,7 +33,8 @@ class MCPServerService(CommonService):
@classmethod
@DB.connection_context()
def get_servers(cls, tenant_id: str, id_list: list[str] | None, page_number, items_per_page, orderby, desc, keywords):
def get_servers(cls, tenant_id: str, id_list: list[str] | None, page_number, items_per_page, orderby, desc,
keywords):
"""Retrieve all MCP servers associated with a tenant.
This method fetches all MCP servers for a given tenant, ordered by creation time.

View File

@ -94,7 +94,8 @@ class SearchService(CommonService):
query = (
cls.model.select(*fields)
.join(User, on=(cls.model.tenant_id == User.id))
.where(((cls.model.tenant_id.in_(joined_tenant_ids)) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value))
.where(((cls.model.tenant_id.in_(joined_tenant_ids)) | (cls.model.tenant_id == user_id)) & (
cls.model.status == StatusEnum.VALID.value))
)
if keywords:

View File

@ -165,7 +165,7 @@ class TaskService(CommonService):
]
tasks = (
cls.model.select(*fields).order_by(cls.model.from_page.asc(), cls.model.create_time.desc())
.where(cls.model.doc_id == doc_id)
.where(cls.model.doc_id == doc_id)
)
tasks = list(tasks.dicts())
if not tasks:
@ -205,18 +205,18 @@ class TaskService(CommonService):
cls.model.select(
*[Document.id, Document.kb_id, Document.location, File.parent_id]
)
.join(Document, on=(cls.model.doc_id == Document.id))
.join(
.join(Document, on=(cls.model.doc_id == Document.id))
.join(
File2Document,
on=(File2Document.document_id == Document.id),
join_type=JOIN.LEFT_OUTER,
)
.join(
.join(
File,
on=(File2Document.file_id == File.id),
join_type=JOIN.LEFT_OUTER,
)
.where(
.where(
Document.status == StatusEnum.VALID.value,
Document.run == TaskStatus.RUNNING.value,
~(Document.type == FileType.VIRTUAL.value),
@ -294,8 +294,8 @@ class TaskService(CommonService):
cls.model.update(progress=prog).where(
(cls.model.id == id) &
(
(cls.model.progress != -1) &
((prog == -1) | (prog > cls.model.progress))
(cls.model.progress != -1) &
((prog == -1) | (prog > cls.model.progress))
)
).execute()
else:
@ -343,6 +343,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
- Task digests are calculated for optimization and reuse
- Previous task chunks may be reused if available
"""
def new_task():
return {
"id": get_uuid(),
@ -515,7 +516,7 @@ def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DE
task["file"] = file
if not REDIS_CONN.queue_product(
get_svr_queue_name(priority), message=task
get_svr_queue_name(priority), message=task
):
return False, "Can't access Redis. Please check the Redis' status."

View File

@ -57,8 +57,10 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def get_my_llms(cls, tenant_id):
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name,
cls.model.used_tokens]
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
return list(objs)
@ -122,7 +124,8 @@ class TenantLLMService(CommonService):
model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
if not model_config:
if mdlnm == "flag-embedding":
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name,
"api_base": ""}
else:
if not mdlnm:
raise LookupError(f"Type of {llm_type} model is not set.")
@ -137,27 +140,33 @@ class TenantLLMService(CommonService):
if llm_type == LLMType.EMBEDDING.value:
if model_config["llm_factory"] not in EmbeddingModel:
return
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
base_url=model_config["api_base"])
if llm_type == LLMType.RERANK:
if model_config["llm_factory"] not in RerankModel:
return
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
base_url=model_config["api_base"])
if llm_type == LLMType.IMAGE2TEXT.value:
if model_config["llm_factory"] not in CvModel:
return
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang,
base_url=model_config["api_base"], **kwargs)
if llm_type == LLMType.CHAT.value:
if model_config["llm_factory"] not in ChatModel:
return
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
base_url=model_config["api_base"], **kwargs)
if llm_type == LLMType.SPEECH2TEXT:
if model_config["llm_factory"] not in Seq2txtModel:
return
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"],
model_name=model_config["llm_name"], lang=lang,
base_url=model_config["api_base"])
if llm_type == LLMType.TTS:
if model_config["llm_factory"] not in TTSModel:
return
@ -194,11 +203,14 @@ class TenantLLMService(CommonService):
try:
num = (
cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name,
cls.model.llm_factory == llm_factory if llm_factory else True)
.execute()
)
except Exception:
logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
logging.exception(
"TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s",
tenant_id, llm_name)
return 0
return num
@ -206,7 +218,9 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def get_openai_models(cls):
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"),
~(cls.model.llm_name == "text-embedding-3-small"),
~(cls.model.llm_name == "text-embedding-3-large")).dicts()
return list(objs)
@classmethod
@ -250,8 +264,9 @@ class LLM4Tenant:
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
self.langfuse = None
if langfuse_keys:
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key,
host=langfuse_keys.host)
if langfuse.auth_check():
self.langfuse = langfuse
trace_id = self.langfuse.create_trace_id()
self.trace_context = {"trace_id": trace_id}
self.trace_context = {"trace_id": trace_id}

View File

@ -2,22 +2,22 @@ from api.db.db_models import UserCanvasVersion, DB
from api.db.services.common_service import CommonService
from peewee import DoesNotExist
class UserCanvasVersionService(CommonService):
model = UserCanvasVersion
@classmethod
@DB.connection_context()
def list_by_canvas_id(cls, user_canvas_id):
try:
user_canvas_version = cls.model.select(
*[cls.model.id,
cls.model.create_time,
cls.model.title,
cls.model.create_date,
cls.model.update_date,
cls.model.user_canvas_id,
cls.model.update_time]
*[cls.model.id,
cls.model.create_time,
cls.model.title,
cls.model.create_date,
cls.model.update_date,
cls.model.user_canvas_id,
cls.model.update_time]
).where(cls.model.user_canvas_id == user_canvas_id)
return user_canvas_version
except DoesNotExist:
@ -46,18 +46,16 @@ class UserCanvasVersionService(CommonService):
@DB.connection_context()
def delete_all_versions(cls, user_canvas_id):
try:
user_canvas_version = cls.model.select().where(cls.model.user_canvas_id == user_canvas_id).order_by(cls.model.create_time.desc())
user_canvas_version = cls.model.select().where(cls.model.user_canvas_id == user_canvas_id).order_by(
cls.model.create_time.desc())
if user_canvas_version.count() > 20:
delete_ids = []
for i in range(20, user_canvas_version.count()):
delete_ids.append(user_canvas_version[i].id)
cls.delete_by_ids(delete_ids)
return True
except DoesNotExist:
return None
except Exception:
return None