Refa: async retrieval process. (#12629)
### Type of change

- [x] Refactoring
- [x] Performance Improvement
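For callers, the practical effect of this refactor is that `Dealer.retrieval` (and the search helpers it uses) are now coroutines: async code awaits them, while synchronous entry points drive them with an event loop. A minimal sketch of both call styles, assuming a `retriever` object and the positional argument order used by the benchmark hunk below:

```python
import asyncio


async def ask(retriever, question, embd_mdl, tenant_id, kb_ids):
    # From async code the refactored retriever is awaited directly.
    return await retriever.retrieval(question, embd_mdl, tenant_id, kb_ids,
                                     1, 30, 0.0, 0.3)


def ask_sync(retriever, question, embd_mdl, tenant_id, kb_ids):
    # Synchronous entry points wrap the coroutine in asyncio.run(),
    # mirroring the benchmark change in this commit.
    return asyncio.run(ask(retriever, question, embd_mdl, tenant_id, kb_ids))
```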
```diff
@@ -36,12 +36,12 @@ class TreeStructuredQueryDecompositionRetrieval:
         self._kg_retrieve = kg_retrieve
         self._lock = asyncio.Lock()
 
-    def _retrieve_information(self, search_query):
+    async def _retrieve_information(self, search_query):
         """Retrieve information from different sources"""
         # 1. Knowledge base retrieval
         kbinfos = []
         try:
-            kbinfos = self._kb_retrieve(question=search_query) if self._kb_retrieve else {"chunks": [], "doc_aggs": []}
+            kbinfos = await self._kb_retrieve(question=search_query) if self._kb_retrieve else {"chunks": [], "doc_aggs": []}
         except Exception as e:
             logging.error(f"Knowledge base retrieval error: {e}")
 
@@ -58,7 +58,7 @@ class TreeStructuredQueryDecompositionRetrieval:
         # 3. Knowledge graph retrieval (if configured)
         try:
             if self.prompt_config.get("use_kg") and self._kg_retrieve:
-                ck = self._kg_retrieve(question=search_query)
+                ck = await self._kg_retrieve(question=search_query)
                 if ck["content_with_weight"]:
                     kbinfos["chunks"].insert(0, ck)
         except Exception as e:
@@ -100,9 +100,9 @@ class TreeStructuredQueryDecompositionRetrieval:
         if callback:
             await callback(f"Searching by `{query}`...")
         st = timer()
-        ret = self._retrieve_information(query)
+        ret = await self._retrieve_information(query)
         if callback:
-            await callback("Retrieval %d results by %.1fms"%(len(ret["chunks"]), (timer()-st)*1000))
+            await callback("Retrieval %d results in %.1fms"%(len(ret["chunks"]), (timer()-st)*1000))
         await self._async_update_chunk_info(chunk_info, ret)
         ret = kb_prompt(ret, self.chat_mdl.max_length*0.5)
 
@@ -111,14 +111,14 @@ class TreeStructuredQueryDecompositionRetrieval:
         suff = await sufficiency_check(self.chat_mdl, question, ret)
         if suff["is_sufficient"]:
             if callback:
-                await callback("Yes, it's sufficient.")
+                await callback(f"Yes, the retrieved information is sufficient for '{question}'.")
             return ret
 
         #if callback:
         #    await callback("The retrieved information is not sufficient. Planing next steps...")
         succ_question_info = await multi_queries_gen(self.chat_mdl, question, query, suff["missing_information"], ret)
         if callback:
-            await callback("Next step is to search for the following questions:\n" + "\n - ".join(step["question"] for step in succ_question_info["questions"]))
+            await callback("Next step is to search for the following questions:</br> - " + "</br> - ".join(step["question"] for step in succ_question_info["questions"]))
         steps = []
         for step in succ_question_info["questions"]:
             steps.append(asyncio.create_task(self._research(chunk_info, step["question"], step["query"], depth-1, callback)))
```
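The last hunk fans each follow-up question out as its own `asyncio.create_task` call. A hedged sketch of the overall pattern (the `asyncio.gather` step that collects the tasks is assumed, not shown in this hunk):

```python
import asyncio


async def research_followups(research, chunk_info, questions, depth, callback):
    # One task per generated sub-question, as in the diff above.
    steps = [
        asyncio.create_task(research(chunk_info, q["question"], q["query"], depth - 1, callback))
        for q in questions
    ]
    # Assumed follow-up: wait for all sub-searches and return their results.
    return await asyncio.gather(*steps)
```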
```diff
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import json
 import os
 import sys
@@ -52,8 +53,8 @@ class Benchmark:
         run = defaultdict(dict)
         query_list = list(qrels.keys())
         for query in query_list:
-            ranks = settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
-                                                 0.0, self.vector_similarity_weight)
+            ranks = asyncio.run(settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
+                                                             0.0, self.vector_similarity_weight))
             if len(ranks["chunks"]) == 0:
                 print(f"deleted query: {query}")
                 del qrels[query]
```
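The benchmark stays synchronous and simply runs one event loop per query via `asyncio.run()`. Since the retrieval path now offloads its blocking work with `asyncio.to_thread`, a hypothetical variant could also evaluate several queries concurrently inside a single loop (names below are illustrative, not part of this commit):

```python
import asyncio


async def run_queries(retriever, queries, embd_mdl, tenant_id, kb_id, weight):
    # Hypothetical: issue all benchmark queries concurrently in one event loop
    # instead of calling asyncio.run() once per query.
    tasks = [
        retriever.retrieval(q, embd_mdl, tenant_id, [kb_id], 1, 30, 0.0, weight)
        for q in queries
    ]
    return await asyncio.gather(*tasks)

# Driven once from the synchronous benchmark loop:
# results = asyncio.run(run_queries(settings.retriever, query_list, embd_mdl, tenant_id, kb_id, weight))
```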
```diff
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import json
 import logging
 import re
@@ -49,8 +50,8 @@ class Dealer:
         keywords: list[str] | None = None
         group_docs: list[list] | None = None
 
-    def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
-        qv, _ = emb_mdl.encode_queries(txt)
+    async def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
+        qv, _ = await asyncio.to_thread(emb_mdl.encode_queries, txt)
         shape = np.array(qv).shape
         if len(shape) > 1:
             raise Exception(
```
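`asyncio.to_thread` is the bridge used throughout this file: the embedding and doc-store clients remain synchronous, and each blocking call is pushed onto a worker thread so the event loop stays responsive. A minimal, self-contained sketch of that pattern with a stand-in blocking function:

```python
import asyncio
import time


def blocking_encode(text: str) -> list[float]:
    # Stand-in for a synchronous client call such as emb_mdl.encode_queries(txt)
    # or dataStore.search(...); it blocks the calling thread.
    time.sleep(0.1)
    return [0.0] * 8


async def encode_many(texts: list[str]) -> list[list[float]]:
    # Each blocking call runs in the default thread pool executor; the event
    # loop is free to schedule other coroutines while the threads work.
    return await asyncio.gather(*(asyncio.to_thread(blocking_encode, t) for t in texts))


if __name__ == "__main__":
    print(len(asyncio.run(encode_many(["a", "b", "c"]))))
```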
```diff
@@ -71,7 +72,7 @@ class Dealer:
                 condition[key] = req[key]
         return condition
 
-    def search(self, req, idx_names: str | list[str],
+    async def search(self, req, idx_names: str | list[str],
                kb_ids: list[str],
                emb_mdl=None,
                highlight: bool | list | None = None,
@@ -114,12 +115,12 @@ class Dealer:
         matchText, keywords = self.qryr.question(qst, min_match=0.3)
         if emb_mdl is None:
             matchExprs = [matchText]
-            res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
+            res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
                                         idx_names, kb_ids, rank_feature=rank_feature)
             total = self.dataStore.get_total(res)
             logging.debug("Dealer.search TOTAL: {}".format(total))
         else:
-            matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
+            matchDense = await self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
             q_vec = matchDense.embedding_data
             if not settings.DOC_ENGINE_INFINITY:
                 src.append(f"q_{len(q_vec)}_vec")
@@ -127,7 +128,7 @@ class Dealer:
             fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05,0.95"})
             matchExprs = [matchText, matchDense, fusionExpr]
 
-            res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
+            res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
                                         idx_names, kb_ids, rank_feature=rank_feature)
             total = self.dataStore.get_total(res)
             logging.debug("Dealer.search TOTAL: {}".format(total))
@@ -135,12 +136,12 @@ class Dealer:
         # If result is empty, try again with lower min_match
         if total == 0:
             if filters.get("doc_id"):
-                res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
+                res = await asyncio.to_thread(self.dataStore.search, src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
                 total = self.dataStore.get_total(res)
             else:
                 matchText, _ = self.qryr.question(qst, min_match=0.1)
                 matchDense.extra_options["similarity"] = 0.17
-                res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
+                res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, [matchText, matchDense, fusionExpr],
                                             orderBy, offset, limit, idx_names, kb_ids,
                                             rank_feature=rank_feature)
                 total = self.dataStore.get_total(res)
@@ -359,7 +360,7 @@ class Dealer:
                                  rag_tokenizer.tokenize(ans).split(),
                                  rag_tokenizer.tokenize(inst).split())
 
-    def retrieval(
+    async def retrieval(
         self,
         question,
         embd_mdl,
@@ -398,7 +399,7 @@ class Dealer:
         if isinstance(tenant_ids, str):
             tenant_ids = tenant_ids.split(",")
 
-        sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight,
+        sres = await self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight,
                            rank_feature=rank_feature)
 
         if rerank_mdl and sres.total > 0:
```