Refa: clean up synchronous functions in chat_model and implement async support for conversation and dialog chats (#11779)

### What problem does this PR solve?

Clean up the synchronous functions in chat_model and make the
conversation and dialog chat pipelines asynchronous.

### Type of change

- [x] Refactoring
- [x] Performance Improvement
Author: Yongteng Lei (committed by GitHub)
Date: 2025-12-08 09:43:03 +08:00
Commit: 51ec708c58 (parent: 9b8971a9de)
10 changed files with 421 additions and 843 deletions
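
The core of the change is that the chat pipeline helpers become async generators, so every caller switches from `for` to `async for`. A minimal caller-side sketch (the SSE framing mirrors the diff below; the surrounding function is illustrative and not code from this PR):

```python
import json

from api.db.services.dialog_service import async_chat  # new async entry point


async def stream_answers(dialog, messages):
    # Before this PR: `for ans in chat(dialog, messages, True): ...`
    # After: the pipeline is an async generator, so it is consumed with
    # `async for` and no longer blocks the event loop between chunks.
    async for ans in async_chat(dialog, messages, True):
        yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
```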

View File

@ -19,7 +19,7 @@ from common.constants import StatusEnum
from api.db.db_models import Conversation, DB
from api.db.services.api_service import API4ConversationService
from api.db.services.common_service import CommonService
from api.db.services.dialog_service import DialogService, chat
from api.db.services.dialog_service import DialogService, async_chat
from common.misc_utils import get_uuid
import json
@ -89,8 +89,7 @@ def structure_answer(conv, ans, message_id, session_id):
conv.reference[-1] = reference
return ans
def completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
assert name, "`name` can not be empty."
dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
assert dia, "You do not own the chat."
@ -112,7 +111,7 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
"reference": {},
"audio_binary": None,
"id": None,
"session_id": session_id
"session_id": session_id
}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
@ -148,7 +147,7 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
if stream:
try:
for ans in chat(dia, msg, True, **kwargs):
async for ans in async_chat(dia, msg, True, **kwargs):
ans = structure_answer(conv, ans, message_id, session_id)
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
ConversationService.update_by_id(conv.id, conv.to_dict())
@ -160,14 +159,13 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
else:
answer = None
for ans in chat(dia, msg, False, **kwargs):
async for ans in async_chat(dia, msg, False, **kwargs):
answer = structure_answer(conv, ans, message_id, session_id)
ConversationService.update_by_id(conv.id, conv.to_dict())
break
yield answer
def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
async def async_iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
e, dia = DialogService.get_by_id(dialog_id)
assert e, "Dialog not found"
if not session_id:
@ -222,7 +220,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg
if stream:
try:
for ans in chat(dia, msg, True, **kwargs):
async for ans in async_chat(dia, msg, True, **kwargs):
ans = structure_answer(conv, ans, message_id, session_id)
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
ensure_ascii=False) + "\n\n"
@ -235,7 +233,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg
else:
answer = None
for ans in chat(dia, msg, False, **kwargs):
async for ans in async_chat(dia, msg, False, **kwargs):
answer = structure_answer(conv, ans, message_id, session_id)
API4ConversationService.append_message(conv.id, conv.to_dict())
break
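
For reference, the new `async_completion` generator can also be driven outside a web framework. This is a sketch only: the module path is assumed from the imports shown above, and the tenant/chat IDs are placeholders.

```python
import asyncio

# Import path assumed from context; adjust to the actual module location.
from api.db.services.conversation_service import async_completion


async def main():
    # Placeholder IDs; async_completion yields SSE-formatted "data: ..." chunks.
    async for chunk in async_completion(
        tenant_id="<tenant-id>",
        chat_id="<chat-id>",
        question="What changed in this release?",
        stream=True,
    ):
        print(chunk, end="")


if __name__ == "__main__":
    asyncio.run(main())
```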

View File

@ -178,7 +178,8 @@ class DialogService(CommonService):
offset += limit
return res
def chat_solo(dialog, messages, stream=True):
async def async_chat_solo(dialog, messages, stream=True):
attachments = ""
if "files" in messages[-1]:
attachments = "\n\n".join(FileService.get_files(messages[-1]["files"]))
@ -197,7 +198,8 @@ def chat_solo(dialog, messages, stream=True):
if stream:
last_ans = ""
delta_ans = ""
for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
answer = ""
async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
answer = ans
delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 16:
@ -208,7 +210,7 @@ def chat_solo(dialog, messages, stream=True):
if delta_ans:
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
else:
answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
user_content = msg[-1].get("content", "[content not available]")
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
@ -347,13 +349,12 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
return []
return list(doc_ids)
def chat(dialog, messages, stream=True, **kwargs):
async def async_chat(dialog, messages, stream=True, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
for ans in chat_solo(dialog, messages, stream):
async for ans in async_chat_solo(dialog, messages, stream):
yield ans
return None
return
chat_start_ts = timer()
@ -400,7 +401,7 @@ def chat(dialog, messages, stream=True, **kwargs):
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
if ans:
yield ans
return None
return
for p in prompt_config["parameters"]:
if p["key"] == "knowledge":
@ -508,7 +509,8 @@ def chat(dialog, messages, stream=True, **kwargs):
empty_res = prompt_config["empty_response"]
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
"audio_binary": tts(tts_mdl, empty_res)}
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
return
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
gen_conf = dialog.llm_setting
@ -612,7 +614,7 @@ def chat(dialog, messages, stream=True, **kwargs):
if stream:
last_ans = ""
answer = ""
for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
if thought:
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
answer = ans
@ -626,19 +628,19 @@ def chat(dialog, messages, stream=True, **kwargs):
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
yield decorate_answer(thought + answer)
else:
answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
user_content = msg[-1].get("content", "[content not available]")
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
res = decorate_answer(answer)
res["audio_binary"] = tts(tts_mdl, answer)
yield res
return None
return
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
sys_prompt = """
You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
Ensure that:
1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it.
2. Write only the SQL, no explanations or additional text.
@ -805,8 +807,7 @@ def tts(tts_mdl, text):
return None
return binascii.hexlify(bin).decode("utf-8")
def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
doc_ids = search_config.get("doc_ids", [])
rerank_mdl = None
kb_ids = search_config.get("kb_ids", kb_ids)
@ -880,7 +881,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
return {"answer": answer, "reference": refs}
answer = ""
for ans in chat_mdl.chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
async for ans in chat_mdl.async_chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
answer = ans
yield {"answer": answer, "reference": {}}
yield decorate_answer(answer)
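
The streaming branches in `async_chat_solo` and `async_chat` above keep the pre-existing delta-buffering behaviour: the model yields cumulative answers, and a chunk is only forwarded once the unsent tail exceeds a small token budget (16 tokens via `num_tokens_from_string` in the diff). A self-contained sketch of that pattern, with a stand-in for the model stream and a rough word count in place of the real tokenizer:

```python
import asyncio


async def fake_llm_stream(prompt: str):
    # Stand-in for chat_mdl.async_chat_streamly: yields growing prefixes.
    text = ""
    for word in ("Streaming", "answers", "arrive", "as", "growing", "prefixes."):
        text += word + " "
        yield text


def rough_token_count(s: str) -> int:
    return len(s.split())  # placeholder for num_tokens_from_string


async def stream_with_threshold(prompt: str, min_tokens: int = 2):
    last_sent = ""
    answer = ""
    async for answer in fake_llm_stream(prompt):
        delta = answer[len(last_sent):]
        if rough_token_count(delta) < min_tokens:
            continue  # keep buffering small deltas
        last_sent = answer
        yield {"answer": answer, "delta": delta}
    if answer != last_sent:  # flush whatever is left at the end
        yield {"answer": answer, "delta": answer[len(last_sent):]}


async def main():
    async for chunk in stream_with_threshold("hello"):
        print(chunk["delta"])


asyncio.run(main())
```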

View File

@ -25,14 +25,17 @@ Provides functionality for evaluating RAG system performance including:
- Configuration recommendations
"""
import asyncio
import logging
import queue
import threading
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from timeit import default_timer as timer
from api.db.db_models import EvaluationDataset, EvaluationCase, EvaluationRun, EvaluationResult
from api.db.services.common_service import CommonService
from api.db.services.dialog_service import DialogService, chat
from api.db.services.dialog_service import DialogService
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp
from common.constants import StatusEnum
@ -40,24 +43,24 @@ from common.constants import StatusEnum
class EvaluationService(CommonService):
"""Service for managing RAG evaluations"""
model = EvaluationDataset
# ==================== Dataset Management ====================
@classmethod
def create_dataset(cls, name: str, description: str, kb_ids: List[str],
def create_dataset(cls, name: str, description: str, kb_ids: List[str],
tenant_id: str, user_id: str) -> Tuple[bool, str]:
"""
Create a new evaluation dataset.
Args:
name: Dataset name
description: Dataset description
kb_ids: List of knowledge base IDs to evaluate against
tenant_id: Tenant ID
user_id: User ID who creates the dataset
Returns:
(success, dataset_id or error_message)
"""
@ -74,15 +77,15 @@ class EvaluationService(CommonService):
"update_time": current_timestamp(),
"status": StatusEnum.VALID.value
}
if not EvaluationDataset.create(**dataset):
return False, "Failed to create dataset"
return True, dataset_id
except Exception as e:
logging.error(f"Error creating evaluation dataset: {e}")
return False, str(e)
@classmethod
def get_dataset(cls, dataset_id: str) -> Optional[Dict[str, Any]]:
"""Get dataset by ID"""
@ -94,9 +97,9 @@ class EvaluationService(CommonService):
except Exception as e:
logging.error(f"Error getting dataset {dataset_id}: {e}")
return None
@classmethod
def list_datasets(cls, tenant_id: str, user_id: str,
def list_datasets(cls, tenant_id: str, user_id: str,
page: int = 1, page_size: int = 20) -> Dict[str, Any]:
"""List datasets for a tenant"""
try:
@ -104,10 +107,10 @@ class EvaluationService(CommonService):
(EvaluationDataset.tenant_id == tenant_id) &
(EvaluationDataset.status == StatusEnum.VALID.value)
).order_by(EvaluationDataset.create_time.desc())
total = query.count()
datasets = query.paginate(page, page_size)
return {
"total": total,
"datasets": [d.to_dict() for d in datasets]
@ -115,7 +118,7 @@ class EvaluationService(CommonService):
except Exception as e:
logging.error(f"Error listing datasets: {e}")
return {"total": 0, "datasets": []}
@classmethod
def update_dataset(cls, dataset_id: str, **kwargs) -> bool:
"""Update dataset"""
@ -127,7 +130,7 @@ class EvaluationService(CommonService):
except Exception as e:
logging.error(f"Error updating dataset {dataset_id}: {e}")
return False
@classmethod
def delete_dataset(cls, dataset_id: str) -> bool:
"""Soft delete dataset"""
@ -139,18 +142,18 @@ class EvaluationService(CommonService):
except Exception as e:
logging.error(f"Error deleting dataset {dataset_id}: {e}")
return False
# ==================== Test Case Management ====================
@classmethod
def add_test_case(cls, dataset_id: str, question: str,
def add_test_case(cls, dataset_id: str, question: str,
reference_answer: Optional[str] = None,
relevant_doc_ids: Optional[List[str]] = None,
relevant_chunk_ids: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None) -> Tuple[bool, str]:
"""
Add a test case to a dataset.
Args:
dataset_id: Dataset ID
question: Test question
@ -158,7 +161,7 @@ class EvaluationService(CommonService):
relevant_doc_ids: Optional list of relevant document IDs
relevant_chunk_ids: Optional list of relevant chunk IDs
metadata: Optional additional metadata
Returns:
(success, case_id or error_message)
"""
@ -174,15 +177,15 @@ class EvaluationService(CommonService):
"metadata": metadata,
"create_time": current_timestamp()
}
if not EvaluationCase.create(**case):
return False, "Failed to create test case"
return True, case_id
except Exception as e:
logging.error(f"Error adding test case: {e}")
return False, str(e)
@classmethod
def get_test_cases(cls, dataset_id: str) -> List[Dict[str, Any]]:
"""Get all test cases for a dataset"""
@ -190,12 +193,12 @@ class EvaluationService(CommonService):
cases = EvaluationCase.select().where(
EvaluationCase.dataset_id == dataset_id
).order_by(EvaluationCase.create_time)
return [c.to_dict() for c in cases]
except Exception as e:
logging.error(f"Error getting test cases for dataset {dataset_id}: {e}")
return []
@classmethod
def delete_test_case(cls, case_id: str) -> bool:
"""Delete a test case"""
@ -206,22 +209,22 @@ class EvaluationService(CommonService):
except Exception as e:
logging.error(f"Error deleting test case {case_id}: {e}")
return False
@classmethod
def import_test_cases(cls, dataset_id: str, cases: List[Dict[str, Any]]) -> Tuple[int, int]:
"""
Bulk import test cases from a list.
Args:
dataset_id: Dataset ID
cases: List of test case dictionaries
Returns:
(success_count, failure_count)
"""
success_count = 0
failure_count = 0
for case_data in cases:
success, _ = cls.add_test_case(
dataset_id=dataset_id,
@ -231,28 +234,28 @@ class EvaluationService(CommonService):
relevant_chunk_ids=case_data.get("relevant_chunk_ids"),
metadata=case_data.get("metadata")
)
if success:
success_count += 1
else:
failure_count += 1
return success_count, failure_count
# ==================== Evaluation Execution ====================
@classmethod
def start_evaluation(cls, dataset_id: str, dialog_id: str,
def start_evaluation(cls, dataset_id: str, dialog_id: str,
user_id: str, name: Optional[str] = None) -> Tuple[bool, str]:
"""
Start an evaluation run.
Args:
dataset_id: Dataset ID
dialog_id: Dialog configuration to evaluate
user_id: User ID who starts the run
name: Optional run name
Returns:
(success, run_id or error_message)
"""
@ -261,12 +264,12 @@ class EvaluationService(CommonService):
success, dialog = DialogService.get_by_id(dialog_id)
if not success:
return False, "Dialog not found"
# Create evaluation run
run_id = get_uuid()
if not name:
name = f"Evaluation Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
run = {
"id": run_id,
"dataset_id": dataset_id,
@ -279,92 +282,128 @@ class EvaluationService(CommonService):
"create_time": current_timestamp(),
"complete_time": None
}
if not EvaluationRun.create(**run):
return False, "Failed to create evaluation run"
# Execute evaluation asynchronously (in production, use task queue)
# For now, we'll execute synchronously
cls._execute_evaluation(run_id, dataset_id, dialog)
return True, run_id
except Exception as e:
logging.error(f"Error starting evaluation: {e}")
return False, str(e)
@classmethod
def _execute_evaluation(cls, run_id: str, dataset_id: str, dialog: Any):
"""
Execute evaluation for all test cases.
This method runs the RAG pipeline for each test case and computes metrics.
"""
try:
# Get all test cases
test_cases = cls.get_test_cases(dataset_id)
if not test_cases:
EvaluationRun.update(
status="FAILED",
complete_time=current_timestamp()
).where(EvaluationRun.id == run_id).execute()
return
# Execute each test case
results = []
for case in test_cases:
result = cls._evaluate_single_case(run_id, case, dialog)
if result:
results.append(result)
# Compute summary metrics
metrics_summary = cls._compute_summary_metrics(results)
# Update run status
EvaluationRun.update(
status="COMPLETED",
metrics_summary=metrics_summary,
complete_time=current_timestamp()
).where(EvaluationRun.id == run_id).execute()
except Exception as e:
logging.error(f"Error executing evaluation {run_id}: {e}")
EvaluationRun.update(
status="FAILED",
complete_time=current_timestamp()
).where(EvaluationRun.id == run_id).execute()
@classmethod
def _evaluate_single_case(cls, run_id: str, case: Dict[str, Any],
def _evaluate_single_case(cls, run_id: str, case: Dict[str, Any],
dialog: Any) -> Optional[Dict[str, Any]]:
"""
Evaluate a single test case.
Args:
run_id: Evaluation run ID
case: Test case dictionary
dialog: Dialog configuration
Returns:
Result dictionary or None if failed
"""
try:
# Prepare messages
messages = [{"role": "user", "content": case["question"]}]
# Execute RAG pipeline
start_time = timer()
answer = ""
retrieved_chunks = []
def _sync_from_async_gen(async_gen):
result_queue: queue.Queue = queue.Queue()
def runner():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
async def consume():
try:
async for item in async_gen:
result_queue.put(item)
except Exception as e:
result_queue.put(e)
finally:
result_queue.put(StopIteration)
loop.run_until_complete(consume())
loop.close()
threading.Thread(target=runner, daemon=True).start()
while True:
item = result_queue.get()
if item is StopIteration:
break
if isinstance(item, Exception):
raise item
yield item
def chat(dialog, messages, stream=True, **kwargs):
from api.db.services.dialog_service import async_chat
return _sync_from_async_gen(async_chat(dialog, messages, stream=stream, **kwargs))
for ans in chat(dialog, messages, stream=False):
if isinstance(ans, dict):
answer = ans.get("answer", "")
retrieved_chunks = ans.get("reference", {}).get("chunks", [])
break
execution_time = timer() - start_time
# Compute metrics
metrics = cls._compute_metrics(
question=case["question"],
@ -374,7 +413,7 @@ class EvaluationService(CommonService):
relevant_chunk_ids=case.get("relevant_chunk_ids"),
dialog=dialog
)
# Save result
result_id = get_uuid()
result = {
@ -388,14 +427,14 @@ class EvaluationService(CommonService):
"token_usage": None, # TODO: Track token usage
"create_time": current_timestamp()
}
EvaluationResult.create(**result)
return result
except Exception as e:
logging.error(f"Error evaluating case {case.get('id')}: {e}")
return None
@classmethod
def _compute_metrics(cls, question: str, generated_answer: str,
reference_answer: Optional[str],
@ -404,69 +443,69 @@ class EvaluationService(CommonService):
dialog: Any) -> Dict[str, float]:
"""
Compute evaluation metrics for a single test case.
Returns:
Dictionary of metric names to values
"""
metrics = {}
# Retrieval metrics (if ground truth chunks provided)
if relevant_chunk_ids:
retrieved_ids = [c.get("chunk_id") for c in retrieved_chunks]
metrics.update(cls._compute_retrieval_metrics(retrieved_ids, relevant_chunk_ids))
# Generation metrics
if generated_answer:
# Basic metrics
metrics["answer_length"] = len(generated_answer)
metrics["has_answer"] = 1.0 if generated_answer.strip() else 0.0
# TODO: Implement advanced metrics using LLM-as-judge
# - Faithfulness (hallucination detection)
# - Answer relevance
# - Context relevance
# - Semantic similarity (if reference answer provided)
return metrics
@classmethod
def _compute_retrieval_metrics(cls, retrieved_ids: List[str],
def _compute_retrieval_metrics(cls, retrieved_ids: List[str],
relevant_ids: List[str]) -> Dict[str, float]:
"""
Compute retrieval metrics.
Args:
retrieved_ids: List of retrieved chunk IDs
relevant_ids: List of relevant chunk IDs (ground truth)
Returns:
Dictionary of retrieval metrics
"""
if not relevant_ids:
return {}
retrieved_set = set(retrieved_ids)
relevant_set = set(relevant_ids)
# Precision: proportion of retrieved that are relevant
precision = len(retrieved_set & relevant_set) / len(retrieved_set) if retrieved_set else 0.0
# Recall: proportion of relevant that were retrieved
recall = len(retrieved_set & relevant_set) / len(relevant_set) if relevant_set else 0.0
# F1 score
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
# Hit rate: whether any relevant chunk was retrieved
hit_rate = 1.0 if (retrieved_set & relevant_set) else 0.0
# MRR (Mean Reciprocal Rank): position of first relevant chunk
mrr = 0.0
for i, chunk_id in enumerate(retrieved_ids, 1):
if chunk_id in relevant_set:
mrr = 1.0 / i
break
return {
"precision": precision,
"recall": recall,
@ -474,45 +513,45 @@ class EvaluationService(CommonService):
"hit_rate": hit_rate,
"mrr": mrr
}
@classmethod
def _compute_summary_metrics(cls, results: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Compute summary metrics across all test cases.
Args:
results: List of result dictionaries
Returns:
Summary metrics dictionary
"""
if not results:
return {}
# Aggregate metrics
metric_sums = {}
metric_counts = {}
for result in results:
metrics = result.get("metrics", {})
for key, value in metrics.items():
if isinstance(value, (int, float)):
metric_sums[key] = metric_sums.get(key, 0) + value
metric_counts[key] = metric_counts.get(key, 0) + 1
# Compute averages
summary = {
"total_cases": len(results),
"avg_execution_time": sum(r.get("execution_time", 0) for r in results) / len(results)
}
for key in metric_sums:
summary[f"avg_{key}"] = metric_sums[key] / metric_counts[key]
return summary
# ==================== Results & Analysis ====================
@classmethod
def get_run_results(cls, run_id: str) -> Dict[str, Any]:
"""Get results for an evaluation run"""
@ -520,11 +559,11 @@ class EvaluationService(CommonService):
run = EvaluationRun.get_by_id(run_id)
if not run:
return {}
results = EvaluationResult.select().where(
EvaluationResult.run_id == run_id
).order_by(EvaluationResult.create_time)
return {
"run": run.to_dict(),
"results": [r.to_dict() for r in results]
@ -532,15 +571,15 @@ class EvaluationService(CommonService):
except Exception as e:
logging.error(f"Error getting run results {run_id}: {e}")
return {}
@classmethod
def get_recommendations(cls, run_id: str) -> List[Dict[str, Any]]:
"""
Analyze evaluation results and provide configuration recommendations.
Args:
run_id: Evaluation run ID
Returns:
List of recommendation dictionaries
"""
@ -548,10 +587,10 @@ class EvaluationService(CommonService):
run = EvaluationRun.get_by_id(run_id)
if not run or not run.metrics_summary:
return []
metrics = run.metrics_summary
recommendations = []
# Low precision: retrieving irrelevant chunks
if metrics.get("avg_precision", 1.0) < 0.7:
recommendations.append({
@ -564,7 +603,7 @@ class EvaluationService(CommonService):
"Reduce top_k to return fewer chunks"
]
})
# Low recall: missing relevant chunks
if metrics.get("avg_recall", 1.0) < 0.7:
recommendations.append({
@ -578,7 +617,7 @@ class EvaluationService(CommonService):
"Check chunk size - may be too large or too small"
]
})
# Slow response time
if metrics.get("avg_execution_time", 0) > 5.0:
recommendations.append({
@ -591,7 +630,7 @@ class EvaluationService(CommonService):
"Consider caching frequently asked questions"
]
})
return recommendations
except Exception as e:
logging.error(f"Error generating recommendations for run {run_id}: {e}")

View File

@ -16,15 +16,17 @@
import asyncio
import inspect
import logging
import queue
import re
import threading
from common.token_utils import num_tokens_from_string
from functools import partial
from typing import Generator
from common.constants import LLMType
from api.db.db_models import LLM
from api.db.services.common_service import CommonService
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
from common.constants import LLMType
from common.token_utils import num_tokens_from_string
class LLMService(CommonService):
@ -33,6 +35,7 @@ class LLMService(CommonService):
def get_init_tenant_llm(user_id):
from common import settings
tenant_llm = []
model_configs = {
@ -193,7 +196,7 @@ class LLMBundle(LLM4Tenant):
generation = self.langfuse.start_generation(
trace_context=self.trace_context,
name="stream_transcription",
metadata={"model": self.llm_name}
metadata={"model": self.llm_name},
)
final_text = ""
used_tokens = 0
@ -217,32 +220,34 @@ class LLMBundle(LLM4Tenant):
if self.langfuse:
generation.update(
output={"output": final_text},
usage_details={"total_tokens": used_tokens}
usage_details={"total_tokens": used_tokens},
)
generation.end()
return
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="stream_transcription", metadata={"model": self.llm_name})
full_text, used_tokens = mdl.transcription(audio)
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens
):
logging.error(
f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}"
generation = self.langfuse.start_generation(
trace_context=self.trace_context,
name="stream_transcription",
metadata={"model": self.llm_name},
)
full_text, used_tokens = mdl.transcription(audio)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
logging.error(f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}")
if self.langfuse:
generation.update(
output={"output": full_text},
usage_details={"total_tokens": used_tokens}
usage_details={"total_tokens": used_tokens},
)
generation.end()
yield {
"event": "final",
"text": full_text,
"streaming": False
"streaming": False,
}
def tts(self, text: str) -> Generator[bytes, None, None]:
@ -289,61 +294,79 @@ class LLMBundle(LLM4Tenant):
return kwargs
else:
return {k: v for k, v in kwargs.items() if k in allowed_params}
def _run_coroutine_sync(self, coro):
try:
asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(coro)
result_queue: queue.Queue = queue.Queue()
def runner():
try:
result_queue.put((True, asyncio.run(coro)))
except Exception as e:
result_queue.put((False, e))
thread = threading.Thread(target=runner, daemon=True)
thread.start()
thread.join()
success, value = result_queue.get_nowait()
if success:
return value
raise value
def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
return self._run_coroutine_sync(self.async_chat(system, history, gen_conf, **kwargs))
chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
if self.is_tools and self.mdl.is_tools:
chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
def _sync_from_async_stream(self, async_gen_fn, *args, **kwargs):
result_queue: queue.Queue = queue.Queue()
use_kwargs = self._clean_param(chat_partial, **kwargs)
txt, used_tokens = chat_partial(**use_kwargs)
txt = self._remove_reasoning_content(txt)
def runner():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
async def consume():
try:
async for item in async_gen_fn(*args, **kwargs):
result_queue.put(item)
except Exception as e:
result_queue.put(e)
finally:
result_queue.put(StopIteration)
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
loop.run_until_complete(consume())
loop.close()
if self.langfuse:
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
generation.end()
threading.Thread(target=runner, daemon=True).start()
return txt
while True:
item = result_queue.get()
if item is StopIteration:
break
if isinstance(item, Exception):
raise item
yield item
def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
ans = ""
chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf)
total_tokens = 0
if self.is_tools and self.mdl.is_tools:
chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
for txt in chat_partial(**use_kwargs):
for txt in self._sync_from_async_stream(self.async_chat_streamly, system, history, gen_conf, **kwargs):
if isinstance(txt, int):
total_tokens = txt
if self.langfuse:
generation.update(output={"output": ans})
generation.end()
break
if txt.endswith("</think>"):
ans = ans[: -len("</think>")]
ans = txt[: -len("</think>")]
continue
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
ans += txt
# concatenation has already been done in async_chat_streamly
ans = txt
yield ans
if total_tokens > 0:
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
def _bridge_sync_stream(self, gen):
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
@ -352,7 +375,7 @@ class LLMBundle(LLM4Tenant):
try:
for item in gen:
loop.call_soon_threadsafe(queue.put_nowait, item)
except Exception as e: # pragma: no cover
except Exception as e:
loop.call_soon_threadsafe(queue.put_nowait, e)
finally:
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
@ -361,18 +384,27 @@ class LLMBundle(LLM4Tenant):
return queue
async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
if self.is_tools and self.mdl.is_tools and hasattr(self.mdl, "chat_with_tools"):
chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_with_tools"):
base_fn = self.mdl.async_chat_with_tools
elif hasattr(self.mdl, "async_chat"):
base_fn = self.mdl.async_chat
else:
raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
generation = None
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
chat_partial = partial(base_fn, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
if hasattr(self.mdl, "async_chat_with_tools") and self.is_tools and self.mdl.is_tools:
txt, used_tokens = await self.mdl.async_chat_with_tools(system, history, gen_conf, **use_kwargs)
elif hasattr(self.mdl, "async_chat"):
txt, used_tokens = await self.mdl.async_chat(system, history, gen_conf, **use_kwargs)
else:
txt, used_tokens = await asyncio.to_thread(chat_partial, **use_kwargs)
try:
txt, used_tokens = await chat_partial(**use_kwargs)
except Exception as e:
if generation:
generation.update(output={"error": str(e)})
generation.end()
raise
txt = self._remove_reasoning_content(txt)
if not self.verbose_tool_use:
@ -381,49 +413,51 @@ class LLMBundle(LLM4Tenant):
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
if generation:
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
generation.end()
return txt
async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
total_tokens = 0
ans = ""
if self.is_tools and self.mdl.is_tools:
if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"):
stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
else:
elif hasattr(self.mdl, "async_chat_streamly"):
stream_fn = getattr(self.mdl, "async_chat_streamly", None)
else:
raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
generation = None
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
if stream_fn:
chat_partial = partial(stream_fn, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
async for txt in chat_partial(**use_kwargs):
if isinstance(txt, int):
total_tokens = txt
break
try:
async for txt in chat_partial(**use_kwargs):
if isinstance(txt, int):
total_tokens = txt
break
if txt.endswith("</think>"):
ans = ans[: -len("</think>")]
if txt.endswith("</think>"):
ans = ans[: -len("</think>")]
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
ans += txt
yield ans
ans += txt
yield ans
except Exception as e:
if generation:
generation.update(output={"error": str(e)})
generation.end()
raise
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
if generation:
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
generation.end()
return
chat_partial = partial(self.mdl.chat_streamly_with_tools if (self.is_tools and self.mdl.is_tools) else self.mdl.chat_streamly, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
queue = self._bridge_sync_stream(chat_partial(**use_kwargs))
while True:
item = await queue.get()
if item is StopAsyncIteration:
break
if isinstance(item, Exception):
raise item
if isinstance(item, int):
total_tokens = item
break
yield item
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))