Refactor embedding batch_size (#3825)

### What problem does this PR solve?

Refactor embedding `batch_size`: call sites no longer batch texts themselves; `encode` is now called once with the full list of texts and handles batching internally. Close #3657

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
Author: Zhichang Yu
Date: 2024-12-03 16:22:39 +08:00
Committed by: GitHub
Parent: 934dbc2e2b
Commit: 92ab7ef659
3 changed files with 160 additions and 109 deletions


```diff
@@ -63,16 +63,13 @@ class Benchmark:
                 run[query][c["chunk_id"]] = c["similarity"]
         return run
 
-    def embedding(self, docs, batch_size=16):
-        vects = []
-        cnts = [d["content_with_weight"] for d in docs]
-        for i in range(0, len(cnts), batch_size):
-            vts, c = self.embd_mdl.encode(cnts[i: i + batch_size])
-            vects.extend(vts.tolist())
-        assert len(docs) == len(vects)
+    def embedding(self, docs):
+        texts = [d["content_with_weight"] for d in docs]
+        embeddings, _ = self.embd_mdl.encode(texts)
+        assert len(docs) == len(embeddings)
         vector_size = 0
         for i, d in enumerate(docs):
-            v = vects[i]
+            v = embeddings[i]
             vector_size = len(v)
             d["q_%d_vec" % len(v)] = v
         return docs, vector_size
```
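
For reference, a minimal sketch of where the batching presumably moved: the embedding model's `encode` now accepts the full list of texts and chunks it internally. The class name `EmbeddingModelBase`, the `batch_size` attribute, and the `_encode_batch` helper are illustrative assumptions, not ragflow's actual API; only the `encode(texts) -> (vectors, token_count)` shape used by the caller above is taken from the diff.

```python
import numpy as np

class EmbeddingModelBase:
    # Hypothetical base class: the names here are assumptions for
    # illustration, not ragflow's actual API. Only the encode() return
    # shape (vectors, token_count) is taken from the diff above.
    def __init__(self, batch_size=16):
        self.batch_size = batch_size

    def _encode_batch(self, texts):
        # Backend-specific call (OpenAI, HuggingFace, ...); returns
        # (np.ndarray of shape [len(texts), dim], token_count).
        raise NotImplementedError

    def encode(self, texts):
        # Batching lives here now, so call sites such as
        # Benchmark.embedding() can pass all texts in one call.
        vects, token_count = [], 0
        for i in range(0, len(texts), self.batch_size):
            vts, c = self._encode_batch(texts[i: i + self.batch_size])
            vects.extend(vts.tolist())
            token_count += c
        return np.array(vects), token_count
```

With batching centralized this way, each caller's `assert len(docs) == len(embeddings)` check stays valid regardless of the configured batch size.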