diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index c5f31cf21..ca3e390ce 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -960,8 +960,9 @@ async def do_handle_task(task): task_tenant_id = task["tenant_id"] task_embedding_id = task["embd_id"] task_language = task["language"] - task_llm_id = task["parser_config"].get("llm_id") or task["llm_id"] - task["llm_id"] = task_llm_id + doc_task_llm_id = task["parser_config"].get("llm_id") or task["llm_id"] + kb_task_llm_id = task['kb_parser_config'].get("llm_id") or task["llm_id"] + task['llm_id'] = kb_task_llm_id task_dataset_id = task["kb_id"] task_doc_id = task["doc_id"] task_document_name = task["name"] @@ -1032,7 +1033,7 @@ async def do_handle_task(task): return # bind LLM for raptor - chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) + chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=kb_task_llm_id, lang=task_language) # run RAPTOR async with kg_limiter: chunks, token_count = await run_raptor_for_kb( @@ -1076,7 +1077,7 @@ async def do_handle_task(task): graphrag_conf = kb_parser_config.get("graphrag", {}) start_ts = timer() - chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) + chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=kb_task_llm_id, lang=task_language) with_resolution = graphrag_conf.get("resolution", False) with_community = graphrag_conf.get("community", False) async with kg_limiter: @@ -1101,6 +1102,7 @@ async def do_handle_task(task): return else: # Standard chunking methods + task['llm_id'] = doc_task_llm_id start_ts = timer() chunks = await build_chunks(task, progress_callback) logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts))