mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Refa: cleanup synchronous functions in chat_model and implement synchronization for conversation and dialog chats (#11779)
### What problem does this PR solve?

Cleanup synchronous functions in chat_model and implement synchronization for conversation and dialog chats.

### Type of change

- [x] Refactoring
- [x] Performance Improvement
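The heart of the synchronization change is bridging an async chat generator into synchronous call sites. Below is a minimal, self-contained sketch of that bridging pattern, not the PR's exact code; names such as `run_async_gen_sync` and `fake_async_chat` are illustrative only.

```python
import asyncio
import queue
import threading
from typing import Any, AsyncGenerator, Generator


def run_async_gen_sync(async_gen: AsyncGenerator[Any, None]) -> Generator[Any, None, None]:
    """Drain an async generator on a dedicated event loop in a worker thread,
    yielding each item to the synchronous caller in order."""
    results: queue.Queue = queue.Queue()
    sentinel = object()  # marks the end of the stream

    def worker() -> None:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        async def drain() -> None:
            try:
                async for item in async_gen:
                    results.put(item)
            except Exception as exc:  # surface errors to the sync side
                results.put(exc)
            finally:
                results.put(sentinel)

        try:
            loop.run_until_complete(drain())
        finally:
            loop.close()

    threading.Thread(target=worker, daemon=True).start()

    while True:
        item = results.get()
        if item is sentinel:
            return
        if isinstance(item, Exception):
            raise item
        yield item


async def fake_async_chat(question: str) -> AsyncGenerator[dict, None]:
    """Stand-in for an async streaming chat function (illustrative only)."""
    for token in ["Hello", ", ", "world"]:
        await asyncio.sleep(0)
        yield {"answer": token}


if __name__ == "__main__":
    # Synchronous code can iterate over the async stream directly.
    for chunk in run_async_gen_sync(fake_async_chat("hi")):
        print(chunk["answer"], end="")
    print()
```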
@@ -25,14 +25,17 @@ Provides functionality for evaluating RAG system performance including:
- Configuration recommendations
"""

import asyncio
import logging
import queue
import threading
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from timeit import default_timer as timer

from api.db.db_models import EvaluationDataset, EvaluationCase, EvaluationRun, EvaluationResult
from api.db.services.common_service import CommonService
from api.db.services.dialog_service import DialogService
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp
from common.constants import StatusEnum
@@ -40,24 +43,24 @@ from common.constants import StatusEnum

class EvaluationService(CommonService):
    """Service for managing RAG evaluations"""

    model = EvaluationDataset

    # ==================== Dataset Management ====================

    @classmethod
    def create_dataset(cls, name: str, description: str, kb_ids: List[str],
                       tenant_id: str, user_id: str) -> Tuple[bool, str]:
        """
        Create a new evaluation dataset.

        Args:
            name: Dataset name
            description: Dataset description
            kb_ids: List of knowledge base IDs to evaluate against
            tenant_id: Tenant ID
            user_id: User ID who creates the dataset

        Returns:
            (success, dataset_id or error_message)
        """
@@ -74,15 +77,15 @@ class EvaluationService(CommonService):
                "update_time": current_timestamp(),
                "status": StatusEnum.VALID.value
            }

            if not EvaluationDataset.create(**dataset):
                return False, "Failed to create dataset"

            return True, dataset_id
        except Exception as e:
            logging.error(f"Error creating evaluation dataset: {e}")
            return False, str(e)

    @classmethod
    def get_dataset(cls, dataset_id: str) -> Optional[Dict[str, Any]]:
        """Get dataset by ID"""
@@ -94,9 +97,9 @@ class EvaluationService(CommonService):
        except Exception as e:
            logging.error(f"Error getting dataset {dataset_id}: {e}")
            return None

    @classmethod
    def list_datasets(cls, tenant_id: str, user_id: str,
                      page: int = 1, page_size: int = 20) -> Dict[str, Any]:
        """List datasets for a tenant"""
        try:
@@ -104,10 +107,10 @@ class EvaluationService(CommonService):
                (EvaluationDataset.tenant_id == tenant_id) &
                (EvaluationDataset.status == StatusEnum.VALID.value)
            ).order_by(EvaluationDataset.create_time.desc())

            total = query.count()
            datasets = query.paginate(page, page_size)

            return {
                "total": total,
                "datasets": [d.to_dict() for d in datasets]
@@ -115,7 +118,7 @@ class EvaluationService(CommonService):
        except Exception as e:
            logging.error(f"Error listing datasets: {e}")
            return {"total": 0, "datasets": []}

    @classmethod
    def update_dataset(cls, dataset_id: str, **kwargs) -> bool:
        """Update dataset"""
@@ -127,7 +130,7 @@ class EvaluationService(CommonService):
        except Exception as e:
            logging.error(f"Error updating dataset {dataset_id}: {e}")
            return False

    @classmethod
    def delete_dataset(cls, dataset_id: str) -> bool:
        """Soft delete dataset"""
@@ -139,18 +142,18 @@ class EvaluationService(CommonService):
        except Exception as e:
            logging.error(f"Error deleting dataset {dataset_id}: {e}")
            return False

    # ==================== Test Case Management ====================

    @classmethod
    def add_test_case(cls, dataset_id: str, question: str,
                      reference_answer: Optional[str] = None,
                      relevant_doc_ids: Optional[List[str]] = None,
                      relevant_chunk_ids: Optional[List[str]] = None,
                      metadata: Optional[Dict[str, Any]] = None) -> Tuple[bool, str]:
        """
        Add a test case to a dataset.

        Args:
            dataset_id: Dataset ID
            question: Test question
@@ -158,7 +161,7 @@ class EvaluationService(CommonService):
            relevant_doc_ids: Optional list of relevant document IDs
            relevant_chunk_ids: Optional list of relevant chunk IDs
            metadata: Optional additional metadata

        Returns:
            (success, case_id or error_message)
        """
@@ -174,15 +177,15 @@ class EvaluationService(CommonService):
                "metadata": metadata,
                "create_time": current_timestamp()
            }

            if not EvaluationCase.create(**case):
                return False, "Failed to create test case"

            return True, case_id
        except Exception as e:
            logging.error(f"Error adding test case: {e}")
            return False, str(e)

    @classmethod
    def get_test_cases(cls, dataset_id: str) -> List[Dict[str, Any]]:
        """Get all test cases for a dataset"""
@@ -190,12 +193,12 @@ class EvaluationService(CommonService):
            cases = EvaluationCase.select().where(
                EvaluationCase.dataset_id == dataset_id
            ).order_by(EvaluationCase.create_time)

            return [c.to_dict() for c in cases]
        except Exception as e:
            logging.error(f"Error getting test cases for dataset {dataset_id}: {e}")
            return []

    @classmethod
    def delete_test_case(cls, case_id: str) -> bool:
        """Delete a test case"""
@@ -206,22 +209,22 @@ class EvaluationService(CommonService):
        except Exception as e:
            logging.error(f"Error deleting test case {case_id}: {e}")
            return False

    @classmethod
    def import_test_cases(cls, dataset_id: str, cases: List[Dict[str, Any]]) -> Tuple[int, int]:
        """
        Bulk import test cases from a list.

        Args:
            dataset_id: Dataset ID
            cases: List of test case dictionaries

        Returns:
            (success_count, failure_count)
        """
        success_count = 0
        failure_count = 0

        for case_data in cases:
            success, _ = cls.add_test_case(
                dataset_id=dataset_id,
@@ -231,28 +234,28 @@ class EvaluationService(CommonService):
                relevant_chunk_ids=case_data.get("relevant_chunk_ids"),
                metadata=case_data.get("metadata")
            )

            if success:
                success_count += 1
            else:
                failure_count += 1

        return success_count, failure_count

    # ==================== Evaluation Execution ====================

    @classmethod
    def start_evaluation(cls, dataset_id: str, dialog_id: str,
                         user_id: str, name: Optional[str] = None) -> Tuple[bool, str]:
        """
        Start an evaluation run.

        Args:
            dataset_id: Dataset ID
            dialog_id: Dialog configuration to evaluate
            user_id: User ID who starts the run
            name: Optional run name

        Returns:
            (success, run_id or error_message)
        """
@@ -261,12 +264,12 @@ class EvaluationService(CommonService):
            success, dialog = DialogService.get_by_id(dialog_id)
            if not success:
                return False, "Dialog not found"

            # Create evaluation run
            run_id = get_uuid()
            if not name:
                name = f"Evaluation Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"

            run = {
                "id": run_id,
                "dataset_id": dataset_id,
@@ -279,92 +282,128 @@ class EvaluationService(CommonService):
                "create_time": current_timestamp(),
                "complete_time": None
            }

            if not EvaluationRun.create(**run):
                return False, "Failed to create evaluation run"

            # Execute evaluation asynchronously (in production, use task queue)
            # For now, we'll execute synchronously
            cls._execute_evaluation(run_id, dataset_id, dialog)

            return True, run_id
        except Exception as e:
            logging.error(f"Error starting evaluation: {e}")
            return False, str(e)

    @classmethod
    def _execute_evaluation(cls, run_id: str, dataset_id: str, dialog: Any):
        """
        Execute evaluation for all test cases.

        This method runs the RAG pipeline for each test case and computes metrics.
        """
        try:
            # Get all test cases
            test_cases = cls.get_test_cases(dataset_id)

            if not test_cases:
                EvaluationRun.update(
                    status="FAILED",
                    complete_time=current_timestamp()
                ).where(EvaluationRun.id == run_id).execute()
                return

            # Execute each test case
            results = []
            for case in test_cases:
                result = cls._evaluate_single_case(run_id, case, dialog)
                if result:
                    results.append(result)

            # Compute summary metrics
            metrics_summary = cls._compute_summary_metrics(results)

            # Update run status
            EvaluationRun.update(
                status="COMPLETED",
                metrics_summary=metrics_summary,
                complete_time=current_timestamp()
            ).where(EvaluationRun.id == run_id).execute()

        except Exception as e:
            logging.error(f"Error executing evaluation {run_id}: {e}")
            EvaluationRun.update(
                status="FAILED",
                complete_time=current_timestamp()
            ).where(EvaluationRun.id == run_id).execute()

    @classmethod
    def _evaluate_single_case(cls, run_id: str, case: Dict[str, Any],
                              dialog: Any) -> Optional[Dict[str, Any]]:
        """
        Evaluate a single test case.

        Args:
            run_id: Evaluation run ID
            case: Test case dictionary
            dialog: Dialog configuration

        Returns:
            Result dictionary or None if failed
        """
        try:
            # Prepare messages
            messages = [{"role": "user", "content": case["question"]}]

            # Execute RAG pipeline
            start_time = timer()
            answer = ""
            retrieved_chunks = []

            def _sync_from_async_gen(async_gen):
                result_queue: queue.Queue = queue.Queue()

                def runner():
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)

                    async def consume():
                        try:
                            async for item in async_gen:
                                result_queue.put(item)
                        except Exception as e:
                            result_queue.put(e)
                        finally:
                            result_queue.put(StopIteration)

                    loop.run_until_complete(consume())
                    loop.close()

                threading.Thread(target=runner, daemon=True).start()

                while True:
                    item = result_queue.get()
                    if item is StopIteration:
                        break
                    if isinstance(item, Exception):
                        raise item
                    yield item

            def chat(dialog, messages, stream=True, **kwargs):
                from api.db.services.dialog_service import async_chat

                return _sync_from_async_gen(async_chat(dialog, messages, stream=stream, **kwargs))

            for ans in chat(dialog, messages, stream=False):
                if isinstance(ans, dict):
                    answer = ans.get("answer", "")
                    retrieved_chunks = ans.get("reference", {}).get("chunks", [])
                    break

            execution_time = timer() - start_time

            # Compute metrics
            metrics = cls._compute_metrics(
                question=case["question"],
@@ -374,7 +413,7 @@ class EvaluationService(CommonService):
                relevant_chunk_ids=case.get("relevant_chunk_ids"),
                dialog=dialog
            )

            # Save result
            result_id = get_uuid()
            result = {
@@ -388,14 +427,14 @@ class EvaluationService(CommonService):
                "token_usage": None,  # TODO: Track token usage
                "create_time": current_timestamp()
            }

            EvaluationResult.create(**result)

            return result
        except Exception as e:
            logging.error(f"Error evaluating case {case.get('id')}: {e}")
            return None

    @classmethod
    def _compute_metrics(cls, question: str, generated_answer: str,
                         reference_answer: Optional[str],
@@ -404,69 +443,69 @@ class EvaluationService(CommonService):
                         dialog: Any) -> Dict[str, float]:
        """
        Compute evaluation metrics for a single test case.

        Returns:
            Dictionary of metric names to values
        """
        metrics = {}

        # Retrieval metrics (if ground truth chunks provided)
        if relevant_chunk_ids:
            retrieved_ids = [c.get("chunk_id") for c in retrieved_chunks]
            metrics.update(cls._compute_retrieval_metrics(retrieved_ids, relevant_chunk_ids))

        # Generation metrics
        if generated_answer:
            # Basic metrics
            metrics["answer_length"] = len(generated_answer)
            metrics["has_answer"] = 1.0 if generated_answer.strip() else 0.0

            # TODO: Implement advanced metrics using LLM-as-judge
            # - Faithfulness (hallucination detection)
            # - Answer relevance
            # - Context relevance
            # - Semantic similarity (if reference answer provided)

        return metrics

    @classmethod
    def _compute_retrieval_metrics(cls, retrieved_ids: List[str],
                                   relevant_ids: List[str]) -> Dict[str, float]:
        """
        Compute retrieval metrics.

        Args:
            retrieved_ids: List of retrieved chunk IDs
            relevant_ids: List of relevant chunk IDs (ground truth)

        Returns:
            Dictionary of retrieval metrics
        """
        if not relevant_ids:
            return {}

        retrieved_set = set(retrieved_ids)
        relevant_set = set(relevant_ids)

        # Precision: proportion of retrieved that are relevant
        precision = len(retrieved_set & relevant_set) / len(retrieved_set) if retrieved_set else 0.0

        # Recall: proportion of relevant that were retrieved
        recall = len(retrieved_set & relevant_set) / len(relevant_set) if relevant_set else 0.0

        # F1 score
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

        # Hit rate: whether any relevant chunk was retrieved
        hit_rate = 1.0 if (retrieved_set & relevant_set) else 0.0

        # MRR (Mean Reciprocal Rank): position of first relevant chunk
        mrr = 0.0
        for i, chunk_id in enumerate(retrieved_ids, 1):
            if chunk_id in relevant_set:
                mrr = 1.0 / i
                break

        return {
            "precision": precision,
            "recall": recall,
@@ -474,45 +513,45 @@ class EvaluationService(CommonService):
            "hit_rate": hit_rate,
            "mrr": mrr
        }

    @classmethod
    def _compute_summary_metrics(cls, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Compute summary metrics across all test cases.

        Args:
            results: List of result dictionaries

        Returns:
            Summary metrics dictionary
        """
        if not results:
            return {}

        # Aggregate metrics
        metric_sums = {}
        metric_counts = {}

        for result in results:
            metrics = result.get("metrics", {})
            for key, value in metrics.items():
                if isinstance(value, (int, float)):
                    metric_sums[key] = metric_sums.get(key, 0) + value
                    metric_counts[key] = metric_counts.get(key, 0) + 1

        # Compute averages
        summary = {
            "total_cases": len(results),
            "avg_execution_time": sum(r.get("execution_time", 0) for r in results) / len(results)
        }

        for key in metric_sums:
            summary[f"avg_{key}"] = metric_sums[key] / metric_counts[key]

        return summary

    # ==================== Results & Analysis ====================

    @classmethod
    def get_run_results(cls, run_id: str) -> Dict[str, Any]:
        """Get results for an evaluation run"""
@@ -520,11 +559,11 @@ class EvaluationService(CommonService):
            run = EvaluationRun.get_by_id(run_id)
            if not run:
                return {}

            results = EvaluationResult.select().where(
                EvaluationResult.run_id == run_id
            ).order_by(EvaluationResult.create_time)

            return {
                "run": run.to_dict(),
                "results": [r.to_dict() for r in results]
@@ -532,15 +571,15 @@ class EvaluationService(CommonService):
        except Exception as e:
            logging.error(f"Error getting run results {run_id}: {e}")
            return {}

    @classmethod
    def get_recommendations(cls, run_id: str) -> List[Dict[str, Any]]:
        """
        Analyze evaluation results and provide configuration recommendations.

        Args:
            run_id: Evaluation run ID

        Returns:
            List of recommendation dictionaries
        """
@@ -548,10 +587,10 @@ class EvaluationService(CommonService):
            run = EvaluationRun.get_by_id(run_id)
            if not run or not run.metrics_summary:
                return []

            metrics = run.metrics_summary
            recommendations = []

            # Low precision: retrieving irrelevant chunks
            if metrics.get("avg_precision", 1.0) < 0.7:
                recommendations.append({
@@ -564,7 +603,7 @@ class EvaluationService(CommonService):
                        "Reduce top_k to return fewer chunks"
                    ]
                })

            # Low recall: missing relevant chunks
            if metrics.get("avg_recall", 1.0) < 0.7:
                recommendations.append({
@@ -578,7 +617,7 @@ class EvaluationService(CommonService):
                        "Check chunk size - may be too large or too small"
                    ]
                })

            # Slow response time
            if metrics.get("avg_execution_time", 0) > 5.0:
                recommendations.append({
@@ -591,7 +630,7 @@ class EvaluationService(CommonService):
                        "Consider caching frequently asked questions"
                    ]
                })

            return recommendations
        except Exception as e:
            logging.error(f"Error generating recommendations for run {run_id}: {e}")
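To make the retrieval metrics above concrete, here is a tiny worked example; the chunk IDs are invented purely for illustration.

```python
# Illustrative only: invented retrieved vs. ground-truth chunk IDs.
retrieved_ids = ["c1", "c7", "c3"]   # retriever output, in rank order
relevant_ids = ["c3", "c9"]          # ground-truth relevant chunks

retrieved_set, relevant_set = set(retrieved_ids), set(relevant_ids)
precision = len(retrieved_set & relevant_set) / len(retrieved_set)   # 1/3 ~= 0.33
recall = len(retrieved_set & relevant_set) / len(relevant_set)       # 1/2 = 0.50
f1 = 2 * precision * recall / (precision + recall)                   # ~= 0.40
hit_rate = 1.0 if retrieved_set & relevant_set else 0.0              # 1.0
mrr = next((1.0 / i for i, c in enumerate(retrieved_ids, 1) if c in relevant_set), 0.0)  # c3 at rank 3 -> 1/3
```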
|
||||
Reference in New Issue
Block a user