Add Tavily as web search tool. (#5349)

### What problem does this PR solve?

#5198

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Author: Kevin Hu
Committed: 2025-02-26 10:21:04 +08:00 (via GitHub)
Parent: e5e9ca0015
Commit: 53b9e7b52f
6 changed files with 3248 additions and 3080 deletions
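The diff below imports `rag.utils.tavily_conn.Tavily`, one of the six changed files, but its source is not shown in this excerpt. For orientation, here is a minimal sketch of what such a connector might look like, built on the official `tavily-python` client; the `retrieve_chunks` return shape (`chunks` / `doc_aggs`) is inferred from the consuming code below, and all field values are assumptions, not the committed implementation:

```python
# Hypothetical sketch of rag/utils/tavily_conn.py -- not the committed file.
# Field names are inferred from how the diff consumes retrieve_chunks().
import logging

from tavily import TavilyClient  # pip install tavily-python


class Tavily:
    def __init__(self, api_key: str):
        self.client = TavilyClient(api_key=api_key)

    def search(self, query: str) -> list[dict]:
        try:
            resp = self.client.search(query=query, search_depth="advanced", max_results=6)
            return resp.get("results", [])
        except Exception:
            logging.exception("Tavily search failed")
        return []

    def retrieve_chunks(self, question: str) -> dict:
        chunks, doc_aggs = [], []
        for r in self.search(question):
            # Shape each web hit like a KB chunk so kb_prompt() and the
            # citation pipeline can treat web and KB results uniformly.
            chunks.append({
                "chunk_id": r["url"],
                "doc_id": r["url"],
                "docnm_kwd": r["title"],
                "content_with_weight": r["content"],
                "similarity": r["score"],
                "url": r["url"],
            })
            doc_aggs.append({"doc_name": r["title"], "doc_id": r["url"], "count": 1})
        return {"chunks": chunks, "doc_aggs": doc_aggs}
```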


@@ -64,6 +64,7 @@ def structure_answer(conv, ans, message_id, session_id):
             "dataset_id": get_value(chunk, "kb_id", "dataset_id"),
             "image_id": get_value(chunk, "image_id", "img_id"),
             "positions": get_value(chunk, "positions", "position_int"),
+            "url": chunk.get("url")
         } for chunk in reference.get("chunks", [])]
     reference["chunks"] = chunk_list


@@ -40,6 +40,7 @@ from rag.nlp.search import index_name
 from rag.settings import TAG_FLD
 from rag.utils import rmSpace, num_tokens_from_string, encoder
 from api.utils.file_utils import get_project_base_directory
+from rag.utils.tavily_conn import Tavily

 class DialogService(CommonService):
@@ -125,6 +126,7 @@ def kb_prompt(kbinfos, max_tokens):
         chunks_num += 1
         if max_tokens * 0.97 < used_token_count:
+            logging.warning(f"Only {i} of {len(knowledges)} retrieved chunks fit within the prompt's token budget.")
             knowledges = knowledges[:i]
             break
     docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]])
@@ -132,7 +134,7 @@ def kb_prompt(kbinfos, max_tokens):
     doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []})
     for ck in kbinfos["chunks"][:chunks_num]:
-        doc2chunks[ck["docnm_kwd"]]["chunks"].append(ck["content_with_weight"])
+        doc2chunks[ck["docnm_kwd"]]["chunks"].append((f"URL: {ck['url']}\n" if "url" in ck else "") + ck["content_with_weight"])
         doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})
     knowledges = []
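With this change, `kb_prompt` prefixes each web chunk with its URL so the model can cite the page it came from; KB chunks render unchanged. A quick illustration with hypothetical chunk data:

```python
# Hypothetical chunks: one from Tavily (has "url"), one from a KB (does not).
chunks = [
    {"docnm_kwd": "Example page", "content_with_weight": "Tavily is a web search API.",
     "url": "https://example.com"},
    {"docnm_kwd": "doc.pdf", "content_with_weight": "Internal KB content."},
]
for ck in chunks:
    print((f"URL: {ck['url']}\n" if "url" in ck else "") + ck["content_with_weight"])
# URL: https://example.com
# Tavily is a web search API.
# Internal KB content.
```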
@@ -295,7 +297,7 @@ def chat(dialog, messages, stream=True, **kwargs):
         knowledges = []
         if prompt_config.get("reasoning", False):
-            for think in reasoning(kbinfos, " ".join(questions), chat_mdl, embd_mdl, tenant_ids, dialog.kb_ids, MAX_SEARCH_LIMIT=3):
+            for think in reasoning(kbinfos, " ".join(questions), chat_mdl, embd_mdl, tenant_ids, dialog.kb_ids, prompt_config, MAX_SEARCH_LIMIT=3):
                 if isinstance(think, str):
                     thought = think
                     knowledges = [t for t in think.split("\n") if t]
@@ -309,6 +311,11 @@ def chat(dialog, messages, stream=True, **kwargs):
                 top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
                 rank_feature=label_question(" ".join(questions), kbs)
             )
+            if prompt_config.get("tavily_api_key"):
+                tav = Tavily(prompt_config["tavily_api_key"])
+                tav_res = tav.retrieve_chunks(" ".join(questions))
+                kbinfos["chunks"].extend(tav_res["chunks"])
+                kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
             if prompt_config.get("use_kg"):
                 ck = settings.kg_retrievaler.retrieval(" ".join(questions),
                                                        tenant_ids,
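Web search is opt-in per dialog: the branch runs only when `tavily_api_key` is present in `prompt_config`. A minimal sketch of the merge it performs, with illustrative data (the real `kbinfos` comes from `settings.retrievaler.retrieval()`):

```python
# Illustrative data; only the extend() calls mirror the new branch above.
prompt_config = {"tavily_api_key": "tvly-***"}  # presence of the key enables web search
kbinfos = {"chunks": [{"docnm_kwd": "doc.pdf", "content_with_weight": "KB text"}],
           "doc_aggs": [{"doc_name": "doc.pdf", "count": 1}]}
tav_res = {"chunks": [{"docnm_kwd": "Example page", "content_with_weight": "web text",
                       "url": "https://example.com"}],
           "doc_aggs": [{"doc_name": "Example page", "count": 1}]}

if prompt_config.get("tavily_api_key"):
    kbinfos["chunks"].extend(tav_res["chunks"])      # web chunks join KB chunks downstream
    kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])  # web pages show up in citations
```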
@@ -852,7 +859,7 @@ Output:
 def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LLMBundle,
-              tenant_ids: list[str], kb_ids: list[str], MAX_SEARCH_LIMIT: int = 3,
+              tenant_ids: list[str], kb_ids: list[str], prompt_config, MAX_SEARCH_LIMIT: int = 3,
               top_n: int = 5, similarity_threshold: float = 0.4, vector_similarity_weight: float = 0.3):
     BEGIN_SEARCH_QUERY = "<|begin_search_query|>"
     END_SEARCH_QUERY = "<|end_search_query|>"
@@ -1023,10 +1030,28 @@ def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LLMBundle,
                     truncated_prev_reasoning += '...\n\n'
                 truncated_prev_reasoning = truncated_prev_reasoning.strip('\n')

+                # Retrieval procedure:
+                # 1. KB search
+                # 2. Web search (optional)
+                # 3. KG search (optional)
                 kbinfos = settings.retrievaler.retrieval(search_query, embd_mdl, tenant_ids, kb_ids, 1, top_n,
                                                          similarity_threshold,
                                                          vector_similarity_weight
                                                          )
+                if prompt_config.get("tavily_api_key"):
+                    tav = Tavily(prompt_config["tavily_api_key"])
+                    tav_res = tav.retrieve_chunks(search_query)
+                    kbinfos["chunks"].extend(tav_res["chunks"])
+                    kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
+                if prompt_config.get("use_kg"):
+                    ck = settings.kg_retrievaler.retrieval(search_query,
+                                                           tenant_ids,
+                                                           kb_ids,
+                                                           embd_mdl,
+                                                           chat_mdl)
+                    if ck["content_with_weight"]:
+                        kbinfos["chunks"].insert(0, ck)
+
+                # Merge chunk info for citations
                 if not chunk_info["chunks"]:
                     for k in chunk_info.keys():
@@ -1048,7 +1073,7 @@ def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LLMBundle,
                     relevant_extraction_prompt.format(
                         prev_reasoning=truncated_prev_reasoning,
                         search_query=search_query,
-                        document="\n".join(kb_prompt(kbinfos, 512))
+                        document="\n".join(kb_prompt(kbinfos, 4096))
                     ),
                     [{"role": "user",
                       "content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}],