Refactor: Use Input Length In DefaultRerank (#9516)

### What problem does this PR solve?

1. Pre-allocate `res` from the input length instead of growing it with `np.append` on every batch
2. Move the `torch_empty_cache()` call to the appropriate location in the retry path

### Type of change

- [x] Refactoring
- [x] Performance Improvement
This commit is contained in:
Stephen Hu
2025-08-18 10:00:27 +08:00
committed by GitHub
parent d874683ae4
commit fb77f9917b

View File

@@ -100,7 +100,7 @@ class DefaultRerank(Base):
         old_dynamic_batch_size = self._dynamic_batch_size
         if max_batch_size is not None:
             self._dynamic_batch_size = max_batch_size
-        res = np.array([], dtype=float)
+        res = np.array(len(pairs), dtype=float)
         i = 0
         while i < len(pairs):
             cur_i = i
@@ -111,7 +111,7 @@ class DefaultRerank(Base):
             try:
                 # call subclass implemented batch processing calculation
                 batch_scores = self._compute_batch_scores(pairs[i : i + current_batch])
-                res = np.append(res, batch_scores)
+                res[i : i + current_batch] = batch_scores
                 i += current_batch
                 self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
                 break
@@ -125,8 +125,8 @@ class DefaultRerank(Base):
                     raise
                 if retry_count >= max_retries:
                     raise RuntimeError("max retry times, still cannot process batch, please check your GPU memory")
-                self.torch_empty_cache()
+                    self.torch_empty_cache()
         self._dynamic_batch_size = old_dynamic_batch_size
         return np.array(res)

(NOTE: the flattened side-by-side rendering shows two identical `self.torch_empty_cache()` lines, one per column; the exact old vs. new indentation/location cannot be recovered from this page — confirm against the repository diff.)