Feat: support tree structured deep-research policy. (#12559)

### What problem does this PR solve?

#12558
### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2026-01-13 09:41:35 +08:00
committed by GitHub
parent 867ec94258
commit 44bada64c9
15 changed files with 1166 additions and 1381 deletions

View File

@ -0,0 +1,20 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .tree_structured_query_decomposition_retrieval import TreeStructuredQueryDecompositionRetrieval as DeepResearcher
__all__ = ['DeepResearcher']

View File

@ -0,0 +1,126 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import logging
from functools import partial
from api.db.services.llm_service import LLMBundle
from rag.prompts import kb_prompt
from rag.prompts.generator import sufficiency_check, multi_queries_gen
from rag.utils.tavily_conn import Tavily
from timeit import default_timer as timer
class TreeStructuredQueryDecompositionRetrieval:
    """Tree-structured deep-research retrieval policy.

    Recursively decomposes a user question into complementary sub-queries:
    each node retrieves evidence, asks the LLM whether it is sufficient, and
    if not, fans out child searches (one asyncio task per generated
    sub-question) up to a maximum depth. All retrieved chunks are merged,
    de-duplicated, into a shared ``chunk_info`` accumulator for citations.
    """

    def __init__(self,
                 chat_mdl: "LLMBundle",
                 prompt_config: dict,
                 kb_retrieve: partial = None,
                 kg_retrieve: partial = None
                 ):
        # LLM used for sufficiency checking and query decomposition.
        self.chat_mdl = chat_mdl
        # Optional keys read here: "tavily_api_key" (web search), "use_kg".
        self.prompt_config = prompt_config
        self._kb_retrieve = kb_retrieve
        self._kg_retrieve = kg_retrieve
        # Serializes concurrent updates to the shared chunk_info accumulator,
        # since sibling research branches run as parallel tasks.
        self._lock = asyncio.Lock()

    def _retrieve_information(self, search_query):
        """Retrieve information for ``search_query`` from all configured sources.

        Sources are tried best-effort in order (knowledge base, Tavily web
        search, knowledge graph); a failure in one source is logged and does
        not prevent the others from contributing.

        Returns:
            dict: ``{"chunks": [...], "doc_aggs": [...]}``.
        """
        # 1. Knowledge base retrieval.
        # Start from a well-formed empty result so the web/KG stages below can
        # always extend it (initializing with `[]` would break them — and the
        # caller — whenever KB retrieval raises).
        kbinfos = {"chunks": [], "doc_aggs": []}
        try:
            if self._kb_retrieve:
                kbinfos = self._kb_retrieve(question=search_query)
        except Exception as e:
            logging.error(f"Knowledge base retrieval error: {e}")
            kbinfos = {"chunks": [], "doc_aggs": []}
        # 2. Web retrieval (if Tavily API is configured)
        try:
            if self.prompt_config.get("tavily_api_key"):
                tav = Tavily(self.prompt_config["tavily_api_key"])
                tav_res = tav.retrieve_chunks(search_query)
                kbinfos["chunks"].extend(tav_res["chunks"])
                kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
        except Exception as e:
            logging.error(f"Web retrieval error: {e}")
        # 3. Knowledge graph retrieval (if configured)
        try:
            if self.prompt_config.get("use_kg") and self._kg_retrieve:
                ck = self._kg_retrieve(question=search_query)
                if ck["content_with_weight"]:
                    # KG context is inserted first so it ranks above plain chunks.
                    kbinfos["chunks"].insert(0, ck)
        except Exception as e:
            logging.error(f"Knowledge graph retrieval error: {e}")
        return kbinfos

    async def _async_update_chunk_info(self, chunk_info, kbinfos):
        """Merge ``kbinfos`` into the shared ``chunk_info`` (citation store).

        Runs under the instance lock because sibling branches update the same
        accumulator concurrently.
        """
        async with self._lock:
            if not chunk_info["chunks"]:
                # First retrieval: adopt the results wholesale.
                for k in chunk_info.keys():
                    chunk_info[k] = kbinfos[k]
            else:
                # Merge newly retrieved information, avoiding duplicates.
                cids = [c["chunk_id"] for c in chunk_info["chunks"]]
                for c in kbinfos["chunks"]:
                    if c["chunk_id"] not in cids:
                        chunk_info["chunks"].append(c)
                dids = [d["doc_id"] for d in chunk_info["doc_aggs"]]
                for d in kbinfos["doc_aggs"]:
                    if d["doc_id"] not in dids:
                        chunk_info["doc_aggs"].append(d)

    async def research(self, chunk_info, question, query, depth=3, callback=None):
        """Public entry point: run the research tree, wrapped in stream markers.

        Returns the aggregated textual result of the root research node (the
        merged chunks also accumulate in ``chunk_info`` as a side effect).
        """
        if callback:
            await callback("<START_DEEP_RESEARCH>")
        res = await self._research(chunk_info, question, query, depth, callback)
        if callback:
            await callback("<END_DEEP_RESEARCH>")
        return res

    async def _research(self, chunk_info, question, query, depth=3, callback=None):
        """One node of the research tree.

        Retrieves for ``query``, checks sufficiency for ``question``, and
        either returns the formatted evidence or recurses on generated
        sub-questions with ``depth - 1``.
        """
        if depth == 0:
            # Max search depth reached; contribute nothing further.
            return ""
        if callback:
            await callback(f"Searching by `{query}`...")
        st = timer()
        ret = self._retrieve_information(query)
        if callback:
            await callback("Retrieval %d results by %.1fms" % (len(ret["chunks"]), (timer() - st) * 1000))
        await self._async_update_chunk_info(chunk_info, ret)
        # Format chunks for prompting, capped to half the model's context.
        ret = kb_prompt(ret, self.chat_mdl.max_length * 0.5)
        if callback:
            await callback("Checking the sufficiency for retrieved information.")
        suff = await sufficiency_check(self.chat_mdl, question, ret)
        # sufficiency_check returns {} on failure; .get avoids a KeyError then.
        if suff.get("is_sufficient"):
            if callback:
                await callback("Yes, it's sufficient.")
            return ret
        # The retrieved information is not sufficient. Planning next steps...
        succ_question_info = await multi_queries_gen(
            self.chat_mdl, question, query, suff.get("missing_information", []), ret)
        questions = succ_question_info.get("questions", [])
        if not questions:
            # Decomposition failed or produced nothing; stop this branch with
            # whatever evidence was gathered.
            return ret
        if callback:
            await callback("Next step is to search for the following questions:\n"
                           + "\n - ".join(step["question"] for step in questions))
        tasks = [
            asyncio.create_task(self._research(chunk_info, step["question"], step["query"], depth - 1, callback))
            for step in questions
        ]
        # return_exceptions=True keeps one failed branch from killing siblings.
        results = await asyncio.gather(*tasks, return_exceptions=True)
        return "\n".join(str(r) for r in results)

View File

@ -382,6 +382,7 @@ class Dealer:
# Ensure RERANK_LIMIT is multiple of page_size
RERANK_LIMIT = math.ceil(64 / page_size) * page_size if page_size > 1 else 1
RERANK_LIMIT = max(30, RERANK_LIMIT)
req = {
"kb_ids": kb_ids,
"doc_ids": doc_ids,

View File

@ -38,7 +38,7 @@ def get_value(d, k1, k2):
def chunks_format(reference):
if not reference or (reference is not dict):
if not reference or not isinstance(reference, dict):
return []
return [
{
@ -485,20 +485,26 @@ async def gen_meta_filter(chat_mdl, meta_data: dict, query: str) -> dict:
return {"conditions": []}
async def gen_json(system_prompt: str, user_prompt: str, chat_mdl, gen_conf=None, max_retry=2):
    """Ask the LLM for a JSON answer, with caching and retry-on-parse-failure.

    Checks the LLM cache first; on a miss, chats with the model up to
    ``max_retry`` times. After a failed parse, the previous answer and the
    parse error are appended to the last message so the model can
    self-correct. Returns the parsed object, or None if every retry fails.

    NOTE: the default for ``gen_conf`` is a None sentinel rather than a
    mutable ``{}`` shared across calls.
    """
    from graphrag.utils import get_llm_cache, set_llm_cache
    if gen_conf is None:
        gen_conf = {}
    cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf)
    if cached:
        return json_repair.loads(cached)
    _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
    err = ""
    ans = ""
    for _ in range(max_retry):
        if ans and err:
            # Feed the failed output plus its parse error back for correction.
            msg[-1]["content"] += f"\nGenerated JSON is as following:\n{ans}\nBut exception while loading:\n{err}\nPlease reconsider and correct it."
        ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:], gen_conf=gen_conf)
        # Strip any <think> block and surrounding markdown code fences.
        ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
        try:
            res = json_repair.loads(ans)
            set_llm_cache(chat_mdl.llm_name, system_prompt, ans, user_prompt, gen_conf)
            return res
        except Exception as e:
            logging.exception(f"Loading json failure: {ans}")
            err += str(e)
TOC_DETECTION = load_prompt("toc_detection")
@ -847,8 +853,6 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
TOC_RELEVANCE_SYSTEM = load_prompt("toc_relevance_system")
TOC_RELEVANCE_USER = load_prompt("toc_relevance_user")
async def relevant_chunks_with_toc(query: str, toc: list[dict], chat_mdl, topn: int = 6):
import numpy as np
try:
@ -876,8 +880,6 @@ async def relevant_chunks_with_toc(query: str, toc: list[dict], chat_mdl, topn:
META_DATA = load_prompt("meta_data")
async def gen_metadata(chat_mdl, schema: dict, content: str):
template = PROMPT_JINJA_ENV.from_string(META_DATA)
for k, desc in schema["properties"].items():
@ -890,3 +892,34 @@ async def gen_metadata(chat_mdl, schema: dict, content: str):
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:])
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
SUFFICIENCY_CHECK = load_prompt("sufficiency_check")


async def sufficiency_check(chat_mdl, question: str, ret_content: str):
    """Judge whether ``ret_content`` suffices to answer ``question``.

    Returns the model's parsed JSON verdict (expected keys: ``is_sufficient``,
    ``reasoning``, ``missing_information``), or ``{}`` when generation or
    parsing fails.
    """
    try:
        system_prompt = PROMPT_JINJA_ENV.from_string(SUFFICIENCY_CHECK).render(
            question=question,
            retrieved_docs=ret_content,
        )
        return await gen_json(system_prompt, "Output:\n", chat_mdl)
    except Exception as e:
        logging.exception(e)
        return {}
MULTI_QUERIES_GEN = load_prompt("multi_queries_gen")


async def multi_queries_gen(chat_mdl, question: str, query: str, missing_infos: list[str], ret_content: str):
    """Generate complementary follow-up questions/queries for missing info.

    Renders the multi-query prompt with the original question/query, the
    missing-information bullet list, and the retrieved content, then asks the
    model for JSON. Returns the parsed object (expected keys: ``reasoning``,
    ``questions``) or ``{}`` on failure.
    """
    try:
        rendered = PROMPT_JINJA_ENV.from_string(MULTI_QUERIES_GEN).render(
            original_question=question,
            original_query=query,
            missing_info="\n - ".join(missing_infos),
            retrieved_docs=ret_content,
        )
        return await gen_json(rendered, "Output:\n", chat_mdl)
    except Exception as e:
        logging.exception(e)
        return {}

View File

@ -0,0 +1,41 @@
You are a query optimization expert.
The user's original query failed to retrieve sufficient information;
please generate multiple complementary improved questions and corresponding queries.
Original query:
{{ original_query }}
Original question:
{{ original_question }}
Currently, retrieved content:
{{ retrieved_docs }}
Missing information:
{{ missing_info }}
Please generate 2-3 complementary queries to help find the missing information. These queries should:
1. Focus on different missing information points.
2. Use different expressions.
3. Avoid being identical to the original query.
4. Remain concise and clear.
Output format (JSON):
```json
{
"reasoning": "Explanation of query generation strategy",
"questions": [
{"question": "Improved question 1", "query": "Improved query 1"},
{"question": "Improved question 2", "query": "Improved query 2"},
{"question": "Improved question 3", "query": "Improved query 3"}
]
}
```
Requirements:
1. Questions array contains 1-3 questions and corresponding queries.
2. Each question length is between 5-200 characters.
3. Each query length is between 1-5 keywords.
4. Each query MUST be in the same language as the retrieved content.
5. DO NOT generate question and query that is similar to the original query.
6. Reasoning explains the generation strategy.

View File

@ -0,0 +1,24 @@
You are an information retrieval evaluation expert. Please assess whether the currently retrieved content is sufficient to answer the user's question.
User question:
{{ question }}
Retrieved content:
{{ retrieved_docs }}
Please determine whether this content is sufficient to answer the user's question.
Output format (JSON):
```json
{
"is_sufficient": true/false,
"reasoning": "Your reasoning for the judgment",
"missing_information": ["Missing information 1", "Missing information 2"]
}
```
Requirements:
1. If the retrieved content contains key information needed to answer the query, judge as sufficient (true).
2. If key information is missing, judge as insufficient (false), and list the missing information.
3. The `reasoning` should be concise and clear.
4. The `missing_information` should only be filled when insufficient, otherwise empty array.