Refa: async retrieval process. (#12629)

### Type of change

- [x] Refactoring
- [x] Performance Improvement
This commit is contained in:
Kevin Hu
2026-01-15 12:28:49 +08:00
committed by GitHub
parent f82628c40c
commit 9a10558f80
11 changed files with 52 additions and 57 deletions

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
import re
@ -49,8 +50,8 @@ class Dealer:
keywords: list[str] | None = None
group_docs: list[list] | None = None
def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
qv, _ = emb_mdl.encode_queries(txt)
async def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
qv, _ = await asyncio.to_thread(emb_mdl.encode_queries, txt)
shape = np.array(qv).shape
if len(shape) > 1:
raise Exception(
@ -71,7 +72,7 @@ class Dealer:
condition[key] = req[key]
return condition
def search(self, req, idx_names: str | list[str],
async def search(self, req, idx_names: str | list[str],
kb_ids: list[str],
emb_mdl=None,
highlight: bool | list | None = None,
@ -114,12 +115,12 @@ class Dealer:
matchText, keywords = self.qryr.question(qst, min_match=0.3)
if emb_mdl is None:
matchExprs = [matchText]
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.get_total(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
else:
matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
matchDense = await self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
q_vec = matchDense.embedding_data
if not settings.DOC_ENGINE_INFINITY:
src.append(f"q_{len(q_vec)}_vec")
@ -127,7 +128,7 @@ class Dealer:
fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05,0.95"})
matchExprs = [matchText, matchDense, fusionExpr]
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.get_total(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
@ -135,12 +136,12 @@ class Dealer:
# If result is empty, try again with lower min_match
if total == 0:
if filters.get("doc_id"):
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
res = await asyncio.to_thread(self.dataStore.search, src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
total = self.dataStore.get_total(res)
else:
matchText, _ = self.qryr.question(qst, min_match=0.1)
matchDense.extra_options["similarity"] = 0.17
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, [matchText, matchDense, fusionExpr],
orderBy, offset, limit, idx_names, kb_ids,
rank_feature=rank_feature)
total = self.dataStore.get_total(res)
@ -359,7 +360,7 @@ class Dealer:
rag_tokenizer.tokenize(ans).split(),
rag_tokenizer.tokenize(inst).split())
def retrieval(
async def retrieval(
self,
question,
embd_mdl,
@ -398,7 +399,7 @@ class Dealer:
if isinstance(tenant_ids, str):
tenant_ids = tenant_ids.split(",")
sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight,
sres = await self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight,
rank_feature=rank_feature)
if rerank_mdl and sres.total > 0: