mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Feat:update check_embedding api (#11254)
### What problem does this PR solve? pr: #10854 change: update check_embedding api ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -442,7 +442,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
tk_count = 0
|
||||
if len(tts) == len(cnts):
|
||||
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
|
||||
tts = np.concatenate([vts[0] for _ in range(len(tts))], axis=0)
|
||||
tts = np.tile(vts[0], (len(cnts), 1))
|
||||
tk_count += c
|
||||
|
||||
@timeout(60)
|
||||
@ -465,8 +465,10 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
if not filename_embd_weight:
|
||||
filename_embd_weight = 0.1
|
||||
title_w = float(filename_embd_weight)
|
||||
vects = (title_w * tts + (1 - title_w) *
|
||||
cnts) if len(tts) == len(cnts) else cnts
|
||||
if tts.ndim == 2 and cnts.ndim == 2 and tts.shape == cnts.shape:
|
||||
vects = title_w * tts + (1 - title_w) * cnts
|
||||
else:
|
||||
vects = cnts
|
||||
|
||||
assert len(vects) == len(docs)
|
||||
vector_size = 0
|
||||
|
||||
Reference in New Issue
Block a user