mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Add Dockerfile for CUDA environment. Refine table search strategy (#123)
This commit is contained in:
@ -169,16 +169,25 @@ def init_kb(row):
|
||||
|
||||
|
||||
def embedding(docs, mdl, parser_config={}, callback=None):
|
||||
batch_size = 32
|
||||
tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [
|
||||
d["content_with_weight"] for d in docs]
|
||||
tk_count = 0
|
||||
if len(tts) == len(cnts):
|
||||
tts, c = mdl.encode(tts)
|
||||
tk_count += c
|
||||
tts_ = np.array([])
|
||||
for i in range(0, len(tts), batch_size):
|
||||
vts, c = mdl.encode(tts[i: i + batch_size])
|
||||
if len(tts_) == 0:
|
||||
tts_ = vts
|
||||
else:
|
||||
tts_ = np.concatenate((tts_, vts), axis=0)
|
||||
tk_count += c
|
||||
callback(prog=0.6 + 0.1 * (i + 1) / len(tts), msg="")
|
||||
tts = tts_
|
||||
|
||||
cnts_ = np.array([])
|
||||
for i in range(0, len(cnts), 8):
|
||||
vts, c = mdl.encode(cnts[i: i+8])
|
||||
for i in range(0, len(cnts), batch_size):
|
||||
vts, c = mdl.encode(cnts[i: i+batch_size])
|
||||
if len(cnts_) == 0: cnts_ = vts
|
||||
else: cnts_ = np.concatenate((cnts_, vts), axis=0)
|
||||
tk_count += c
|
||||
|
||||
Reference in New Issue
Block a user