mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Refa: Remove useless conver and fix a bug for DefaultRerank (#8887)
### What problem does this PR solve? 1. bug when re-try, we need to reset i. 2. remove useless convert ### Type of change - [x] Refactoring
This commit is contained in:
@ -101,9 +101,10 @@ class DefaultRerank(Base):
|
|||||||
old_dynamic_batch_size = self._dynamic_batch_size
|
old_dynamic_batch_size = self._dynamic_batch_size
|
||||||
if max_batch_size is not None:
|
if max_batch_size is not None:
|
||||||
self._dynamic_batch_size = max_batch_size
|
self._dynamic_batch_size = max_batch_size
|
||||||
res = []
|
res = np.array([], dtype=float)
|
||||||
i = 0
|
i = 0
|
||||||
while i < len(pairs):
|
while i < len(pairs):
|
||||||
|
cur_i = i
|
||||||
current_batch = self._dynamic_batch_size
|
current_batch = self._dynamic_batch_size
|
||||||
max_retries = 5
|
max_retries = 5
|
||||||
retry_count = 0
|
retry_count = 0
|
||||||
@ -111,7 +112,7 @@ class DefaultRerank(Base):
|
|||||||
try:
|
try:
|
||||||
# call subclass implemented batch processing calculation
|
# call subclass implemented batch processing calculation
|
||||||
batch_scores = self._compute_batch_scores(pairs[i : i + current_batch])
|
batch_scores = self._compute_batch_scores(pairs[i : i + current_batch])
|
||||||
res.extend(batch_scores)
|
res = np.append(res, batch_scores)
|
||||||
i += current_batch
|
i += current_batch
|
||||||
self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
|
self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
|
||||||
break
|
break
|
||||||
@ -119,6 +120,7 @@ class DefaultRerank(Base):
|
|||||||
if "CUDA out of memory" in str(e) and current_batch > self._min_batch_size:
|
if "CUDA out of memory" in str(e) and current_batch > self._min_batch_size:
|
||||||
current_batch = max(current_batch // 2, self._min_batch_size)
|
current_batch = max(current_batch // 2, self._min_batch_size)
|
||||||
self.torch_empty_cache()
|
self.torch_empty_cache()
|
||||||
|
i = cur_i # reset i to the start of the current batch
|
||||||
retry_count += 1
|
retry_count += 1
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
@ -134,7 +136,7 @@ class DefaultRerank(Base):
|
|||||||
scores = self._model.compute_score(batch_pairs)
|
scores = self._model.compute_score(batch_pairs)
|
||||||
else:
|
else:
|
||||||
scores = self._model.compute_score(batch_pairs, max_length=max_length)
|
scores = self._model.compute_score(batch_pairs, max_length=max_length)
|
||||||
scores = sigmoid(np.array(scores)).tolist()
|
scores = sigmoid(np.array(scores))
|
||||||
if not isinstance(scores, Iterable):
|
if not isinstance(scores, Iterable):
|
||||||
scores = [scores]
|
scores = [scores]
|
||||||
return scores
|
return scores
|
||||||
|
|||||||
Reference in New Issue
Block a user