fix benchmark issue (#3324)

### What problem does this PR solve?

The retrieval benchmark script was broken. This PR fixes it by passing the index name to `retrievaler.retrieval` unmodified (and paging from 1 instead of 0), adding the `kb_id`, `docnm_kwd`, and `doc_id` fields that indexed chunk documents need, and wrapping the qrels/run dicts in ranx `Qrels`/`Run` objects before calling `evaluate`.
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
Kevin Hu, 2024-11-11 10:14:30 +08:00, committed by GitHub
parent 7c486ee3f9, commit 5e5a35191e


```diff
@@ -30,6 +30,7 @@ from rag.utils.es_conn import ELASTICSEARCH
 from ranx import evaluate
 import pandas as pd
 from tqdm import tqdm
+from ranx import Qrels, Run


 class Benchmark:
@@ -50,8 +51,8 @@ class Benchmark:
         query_list = list(qrels.keys())
         for query in query_list:
-            ranks = retrievaler.retrieval(query, self.embd_mdl, dataset_idxnm.replace("ragflow_", ""),
-                                          [self.kb.id], 0, 30,
+            ranks = retrievaler.retrieval(query, self.embd_mdl,
+                                          dataset_idxnm, [self.kb.id], 1, 30,
                                           0.0, self.vector_similarity_weight)
             for c in ranks["chunks"]:
                 if "vector" in c:
```
```diff
@@ -105,7 +106,9 @@ class Benchmark:
             for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
                 d = {
                     "id": get_uuid(),
-                    "kb_id": self.kb.id
+                    "kb_id": self.kb.id,
+                    "docnm_kwd": "xxxxx",
+                    "doc_id": "ksksks"
                 }
                 tokenize(d, text, "english")
                 docs.append(d)
@@ -137,7 +140,10 @@ class Benchmark:
             for rel, text in zip(data.iloc[i]["search_results"]['rank'],
                                  data.iloc[i]["search_results"]['search_context']):
                 d = {
-                    "id": get_uuid()
+                    "id": get_uuid(),
+                    "kb_id": self.kb.id,
+                    "docnm_kwd": "xxxxx",
+                    "doc_id": "ksksks"
                 }
                 tokenize(d, text, "english")
                 docs.append(d)
@@ -182,7 +188,10 @@ class Benchmark:
                 text = corpus_total[tmp_data.iloc[i]['docid']]
                 rel = tmp_data.iloc[i]['relevance']
                 d = {
-                    "id": get_uuid()
+                    "id": get_uuid(),
+                    "kb_id": self.kb.id,
+                    "docnm_kwd": "xxxxx",
+                    "doc_id": "ksksks"
                 }
                 tokenize(d, text, 'english')
                 docs.append(d)
```
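All three dataset indexers (ms_marco_v1.1, trivia_qa, miracl) now build chunk documents with the same extra fields, which downstream indexing and retrieval expect to be present. A sketch of one such document; the field semantics in the comments are assumptions, and the placeholder values mirror the commit:

```python
# Sketch of one indexable chunk document (field semantics are
# assumptions; "xxxxx" and "ksksks" are the commit's placeholders).
d = {
    "id": get_uuid(),        # unique chunk id
    "kb_id": self.kb.id,     # knowledge base the chunk belongs to
    "docnm_kwd": "xxxxx",    # document-name keyword field
    "doc_id": "ksksks",      # parent document id
}
tokenize(d, text, "english")  # fills in the tokenized content fields
docs.append(d)
```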
```diff
@@ -204,7 +213,7 @@ class Benchmark:
         for run_i in tqdm(range(len(run_keys)), desc="Calculating ndcg@10 for single query"):
             key = run_keys[run_i]
             keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key],
-                                'ndcg@10': evaluate({key: qrels[key]}, {key: run[key]}, "ndcg@10")})
+                                'ndcg@10': evaluate(Qrels({key: qrels[key]}), Run({key: run[key]}), "ndcg@10")})
         keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10'])
         with open(os.path.join(file_path, dataset + 'result.md'), 'w', encoding='utf-8') as f:
             f.write('## Score For Every Query\n')
```
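Wrapping the plain dicts in ranx's typed containers is what makes `evaluate` work here; the dataset-level calls below get the same treatment. A self-contained sketch of the pattern, with illustrative query/document ids and scores:

```python
from ranx import Qrels, Run, evaluate

# Qrels hold graded relevance judgments, Runs hold retrieval scores;
# both are nested dicts keyed by query id, then document id.
qrels = Qrels({"q_1": {"doc_a": 1, "doc_b": 0}})
run = Run({"q_1": {"doc_a": 0.9, "doc_b": 0.4}})

print(evaluate(qrels, run, "ndcg@10"))                    # single metric -> float
print(evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))  # metric list -> dict
```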
```diff
@@ -222,12 +231,12 @@ class Benchmark:
         if dataset == "ms_marco_v1.1":
             qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
             run = self._get_retrieval(qrels, "benchmark_ms_marco_v1.1")
-            print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
+            print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
             self.save_results(qrels, run, texts, dataset, file_path)
         if dataset == "trivia_qa":
             qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
             run = self._get_retrieval(qrels, "benchmark_trivia_qa")
-            print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
+            print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
             self.save_results(qrels, run, texts, dataset, file_path)
         if dataset == "miracl":
             for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
@@ -248,7 +257,7 @@ class Benchmark:
                     os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang),
                     "benchmark_miracl_" + lang)
                 run = self._get_retrieval(qrels, "benchmark_miracl_" + lang)
-                print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
+                print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
                 self.save_results(qrels, run, texts, dataset, file_path)
```