From 51ec708c58d2a5edaf20bf8ed35e09cfcbe291c8 Mon Sep 17 00:00:00 2001 From: Yongteng Lei Date: Mon, 8 Dec 2025 09:43:03 +0800 Subject: [PATCH] Refa: cleanup synchronous functions in chat_model and implement synchronization for conversation and dialog chats (#11779) ### What problem does this PR solve? Cleanup synchronous functions in chat_model and implement synchronization for conversation and dialog chats. ### Type of change - [x] Refactoring - [x] Performance Improvement --- api/apps/conversation_app.py | 12 +- api/apps/langfuse_app.py | 19 +- api/apps/llm_app.py | 4 +- api/apps/sdk/session.py | 25 +- api/db/services/conversation_service.py | 18 +- api/db/services/dialog_service.py | 33 +- api/db/services/evaluation_service.py | 233 ++++---- api/db/services/llm_service.py | 208 ++++--- rag/llm/__init__.py | 5 + rag/llm/chat_model.py | 707 ++++-------------------- 10 files changed, 421 insertions(+), 843 deletions(-) diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index 89630e4a4..337cb74df 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -23,7 +23,7 @@ from quart import Response, request from api.apps import current_user, login_required from api.db.db_models import APIToken from api.db.services.conversation_service import ConversationService, structure_answer -from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap +from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap from api.db.services.llm_service import LLMBundle from api.db.services.search_service import SearchService from api.db.services.tenant_llm_service import TenantLLMService @@ -218,10 +218,10 @@ async def completion(): dia.llm_setting = chat_model_config is_embedded = bool(chat_model_id) - def stream(): + async def stream(): nonlocal dia, msg, req, conv try: - for ans in chat(dia, msg, True, **req): + async for ans in async_chat(dia, msg, True, **req): ans = structure_answer(conv, ans, message_id, conv.id) yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" if not is_embedded: @@ -241,7 +241,7 @@ async def completion(): else: answer = None - for ans in chat(dia, msg, **req): + async for ans in async_chat(dia, msg, **req): answer = structure_answer(conv, ans, message_id, conv.id) if not is_embedded: ConversationService.update_by_id(conv.id, conv.to_dict()) @@ -406,10 +406,10 @@ async def ask_about(): if search_app: search_config = search_app.get("search_config", {}) - def stream(): + async def stream(): nonlocal req, uid try: - for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config): + async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config): yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" except Exception as e: yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n" diff --git a/api/apps/langfuse_app.py b/api/apps/langfuse_app.py index 8a05c0d4c..1d7993d36 100644 --- a/api/apps/langfuse_app.py +++ b/api/apps/langfuse_app.py @@ -34,8 +34,9 @@ async def set_api_key(): if not all([secret_key, public_key, host]): return get_error_data_result(message="Missing required fields") + current_user_id = current_user.id langfuse_keys = dict( - tenant_id=current_user.id, + tenant_id=current_user_id, secret_key=secret_key, public_key=public_key, host=host, @@ -45,23 +46,24 
@@ async def set_api_key(): if not langfuse.auth_check(): return get_error_data_result(message="Invalid Langfuse keys") - langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id) + langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id) with DB.atomic(): try: if not langfuse_entry: TenantLangfuseService.save(**langfuse_keys) else: - TenantLangfuseService.update_by_tenant(tenant_id=current_user.id, langfuse_keys=langfuse_keys) + TenantLangfuseService.update_by_tenant(tenant_id=current_user_id, langfuse_keys=langfuse_keys) return get_json_result(data=langfuse_keys) except Exception as e: - server_error_response(e) + return server_error_response(e) @manager.route("/api_key", methods=["GET"]) # noqa: F821 @login_required @validate_request() def get_api_key(): - langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user.id) + current_user_id = current_user.id + langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user_id) if not langfuse_entry: return get_json_result(message="Have not record any Langfuse keys.") @@ -72,7 +74,7 @@ def get_api_key(): except langfuse.api.core.api_error.ApiError as api_err: return get_json_result(message=f"Error from Langfuse: {api_err}") except Exception as e: - server_error_response(e) + return server_error_response(e) langfuse_entry["project_id"] = langfuse.api.projects.get().dict()["data"][0]["id"] langfuse_entry["project_name"] = langfuse.api.projects.get().dict()["data"][0]["name"] @@ -84,7 +86,8 @@ def get_api_key(): @login_required @validate_request() def delete_api_key(): - langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id) + current_user_id = current_user.id + langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id) if not langfuse_entry: return get_json_result(message="Have not record any Langfuse keys.") @@ -93,4 +96,4 @@ def delete_api_key(): TenantLangfuseService.delete_model(langfuse_entry) return get_json_result(data=True) except Exception as e: - server_error_response(e) + return server_error_response(e) diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 018fb4bca..d24a4bb44 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -74,7 +74,7 @@ async def set_api_key(): assert factory in ChatModel, f"Chat model from {factory} is not supported yet." mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra) try: - m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50}) + m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50}) if m.find("**ERROR**") >= 0: raise Exception(m) chat_passed = True @@ -217,7 +217,7 @@ async def add_llm(): **extra, ) try: - m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9}) + m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! 
How are you doing!"}], {"temperature": 0.9}) if not tc and m.find("**ERROR**:") >= 0: raise Exception(m) except Exception as e: diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index e94f14fcc..fe4723984 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -26,9 +26,10 @@ from api.db.db_models import APIToken from api.db.services.api_service import API4ConversationService from api.db.services.canvas_service import UserCanvasService, completion_openai from api.db.services.canvas_service import completion as agent_completion -from api.db.services.conversation_service import ConversationService, iframe_completion -from api.db.services.conversation_service import completion as rag_completion -from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap, meta_filter +from api.db.services.conversation_service import ConversationService +from api.db.services.conversation_service import async_iframe_completion as iframe_completion +from api.db.services.conversation_service import async_completion as rag_completion +from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap, meta_filter from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle @@ -141,7 +142,7 @@ async def chat_completion(tenant_id, chat_id): return resp else: answer = None - for ans in rag_completion(tenant_id, chat_id, **req): + async for ans in rag_completion(tenant_id, chat_id, **req): answer = ans break return get_result(data=answer) @@ -245,7 +246,7 @@ async def chat_completion_openai_like(tenant_id, chat_id): # The value for the usage field on all chunks except for the last one will be null. # The usage field on the last chunk contains token usage statistics for the entire request. # The choices field on the last chunk will always be an empty array []. 
- def streamed_response_generator(chat_id, dia, msg): + async def streamed_response_generator(chat_id, dia, msg): token_used = 0 answer_cache = "" reasoning_cache = "" @@ -274,7 +275,7 @@ async def chat_completion_openai_like(tenant_id, chat_id): } try: - for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference): + async for ans in async_chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference): last_ans = ans answer = ans["answer"] @@ -342,7 +343,7 @@ async def chat_completion_openai_like(tenant_id, chat_id): return resp else: answer = None - for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference): + async for ans in async_chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference): # focus answer content only answer = ans break @@ -733,10 +734,10 @@ async def ask_about(tenant_id): return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") uid = tenant_id - def stream(): + async def stream(): nonlocal req, uid try: - for ans in ask(req["question"], req["kb_ids"], uid): + async for ans in async_ask(req["question"], req["kb_ids"], uid): yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" except Exception as e: yield "data:" + json.dumps( @@ -827,7 +828,7 @@ async def chatbot_completions(dialog_id): resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") return resp - for answer in iframe_completion(dialog_id, **req): + async for answer in iframe_completion(dialog_id, **req): return get_result(data=answer) @@ -918,10 +919,10 @@ async def ask_about_embedded(): if search_app := SearchService.get_detail(search_id): search_config = search_app.get("search_config", {}) - def stream(): + async def stream(): nonlocal req, uid try: - for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config): + async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config): yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" except Exception as e: yield "data:" + json.dumps( diff --git a/api/db/services/conversation_service.py b/api/db/services/conversation_service.py index 60f8e55b1..aaec72bf5 100644 --- a/api/db/services/conversation_service.py +++ b/api/db/services/conversation_service.py @@ -19,7 +19,7 @@ from common.constants import StatusEnum from api.db.db_models import Conversation, DB from api.db.services.api_service import API4ConversationService from api.db.services.common_service import CommonService -from api.db.services.dialog_service import DialogService, chat +from api.db.services.dialog_service import DialogService, async_chat from common.misc_utils import get_uuid import json @@ -89,8 +89,7 @@ def structure_answer(conv, ans, message_id, session_id): conv.reference[-1] = reference return ans - -def completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs): +async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs): assert name, "`name` can not be empty." dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value) assert dia, "You do not own the chat." 
@@ -112,7 +111,7 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None "reference": {}, "audio_binary": None, "id": None, - "session_id": session_id + "session_id": session_id }}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" @@ -148,7 +147,7 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None if stream: try: - for ans in chat(dia, msg, True, **kwargs): + async for ans in async_chat(dia, msg, True, **kwargs): ans = structure_answer(conv, ans, message_id, session_id) yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n" ConversationService.update_by_id(conv.id, conv.to_dict()) @@ -160,14 +159,13 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None else: answer = None - for ans in chat(dia, msg, False, **kwargs): + async for ans in async_chat(dia, msg, False, **kwargs): answer = structure_answer(conv, ans, message_id, session_id) ConversationService.update_by_id(conv.id, conv.to_dict()) break yield answer - -def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs): +async def async_iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs): e, dia = DialogService.get_by_id(dialog_id) assert e, "Dialog not found" if not session_id: @@ -222,7 +220,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg if stream: try: - for ans in chat(dia, msg, True, **kwargs): + async for ans in async_chat(dia, msg, True, **kwargs): ans = structure_answer(conv, ans, message_id, session_id) yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" @@ -235,7 +233,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg else: answer = None - for ans in chat(dia, msg, False, **kwargs): + async for ans in async_chat(dia, msg, False, **kwargs): answer = structure_answer(conv, ans, message_id, session_id) API4ConversationService.append_message(conv.id, conv.to_dict()) break diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 4afdd1f3c..43e345cd2 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -178,7 +178,8 @@ class DialogService(CommonService): offset += limit return res -def chat_solo(dialog, messages, stream=True): + +async def async_chat_solo(dialog, messages, stream=True): attachments = "" if "files" in messages[-1]: attachments = "\n\n".join(FileService.get_files(messages[-1]["files"])) @@ -197,7 +198,8 @@ def chat_solo(dialog, messages, stream=True): if stream: last_ans = "" delta_ans = "" - for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting): + answer = "" + async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting): answer = ans delta_ans = ans[len(last_ans):] if num_tokens_from_string(delta_ans) < 16: @@ -208,7 +210,7 @@ def chat_solo(dialog, messages, stream=True): if delta_ans: yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()} else: - answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting) + answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting) user_content = msg[-1].get("content", "[content not available]") logging.debug("User: {}|Assistant: {}".format(user_content, 
answer)) yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()} @@ -347,13 +349,12 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"): return [] return list(doc_ids) - -def chat(dialog, messages, stream=True, **kwargs): +async def async_chat(dialog, messages, stream=True, **kwargs): assert messages[-1]["role"] == "user", "The last content of this conversation is not from user." if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"): - for ans in chat_solo(dialog, messages, stream): + async for ans in async_chat_solo(dialog, messages, stream): yield ans - return None + return chat_start_ts = timer() @@ -400,7 +401,7 @@ def chat(dialog, messages, stream=True, **kwargs): ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids) if ans: yield ans - return None + return for p in prompt_config["parameters"]: if p["key"] == "knowledge": @@ -508,7 +509,8 @@ def chat(dialog, messages, stream=True, **kwargs): empty_res = prompt_config["empty_response"] yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res)} - return {"answer": prompt_config["empty_response"], "reference": kbinfos} + yield {"answer": prompt_config["empty_response"], "reference": kbinfos} + return kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges) gen_conf = dialog.llm_setting @@ -612,7 +614,7 @@ def chat(dialog, messages, stream=True, **kwargs): if stream: last_ans = "" answer = "" - for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf): + async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf): if thought: ans = re.sub(r"^.*", "", ans, flags=re.DOTALL) answer = ans @@ -626,19 +628,19 @@ def chat(dialog, messages, stream=True, **kwargs): yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)} yield decorate_answer(thought + answer) else: - answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf) + answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf) user_content = msg[-1].get("content", "[content not available]") logging.debug("User: {}|Assistant: {}".format(user_content, answer)) res = decorate_answer(answer) res["audio_binary"] = tts(tts_mdl, answer) yield res - return None + return def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None): sys_prompt = """ -You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question. +You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question. Ensure that: 1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it. 2. Write only the SQL, no explanations or additional text. 
@@ -805,8 +807,7 @@ def tts(tts_mdl, text): return None return binascii.hexlify(bin).decode("utf-8") - -def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): +async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): doc_ids = search_config.get("doc_ids", []) rerank_mdl = None kb_ids = search_config.get("kb_ids", kb_ids) @@ -880,7 +881,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): return {"answer": answer, "reference": refs} answer = "" - for ans in chat_mdl.chat_streamly(sys_prompt, msg, {"temperature": 0.1}): + async for ans in chat_mdl.async_chat_streamly(sys_prompt, msg, {"temperature": 0.1}): answer = ans yield {"answer": answer, "reference": {}} yield decorate_answer(answer) diff --git a/api/db/services/evaluation_service.py b/api/db/services/evaluation_service.py index 81b4c44fe..c5a24176d 100644 --- a/api/db/services/evaluation_service.py +++ b/api/db/services/evaluation_service.py @@ -25,14 +25,17 @@ Provides functionality for evaluating RAG system performance including: - Configuration recommendations """ +import asyncio import logging +import queue +import threading from typing import List, Dict, Any, Optional, Tuple from datetime import datetime from timeit import default_timer as timer from api.db.db_models import EvaluationDataset, EvaluationCase, EvaluationRun, EvaluationResult from api.db.services.common_service import CommonService -from api.db.services.dialog_service import DialogService, chat +from api.db.services.dialog_service import DialogService from common.misc_utils import get_uuid from common.time_utils import current_timestamp from common.constants import StatusEnum @@ -40,24 +43,24 @@ from common.constants import StatusEnum class EvaluationService(CommonService): """Service for managing RAG evaluations""" - + model = EvaluationDataset - + # ==================== Dataset Management ==================== - + @classmethod - def create_dataset(cls, name: str, description: str, kb_ids: List[str], + def create_dataset(cls, name: str, description: str, kb_ids: List[str], tenant_id: str, user_id: str) -> Tuple[bool, str]: """ Create a new evaluation dataset. 
- + Args: name: Dataset name description: Dataset description kb_ids: List of knowledge base IDs to evaluate against tenant_id: Tenant ID user_id: User ID who creates the dataset - + Returns: (success, dataset_id or error_message) """ @@ -74,15 +77,15 @@ class EvaluationService(CommonService): "update_time": current_timestamp(), "status": StatusEnum.VALID.value } - + if not EvaluationDataset.create(**dataset): return False, "Failed to create dataset" - + return True, dataset_id except Exception as e: logging.error(f"Error creating evaluation dataset: {e}") return False, str(e) - + @classmethod def get_dataset(cls, dataset_id: str) -> Optional[Dict[str, Any]]: """Get dataset by ID""" @@ -94,9 +97,9 @@ class EvaluationService(CommonService): except Exception as e: logging.error(f"Error getting dataset {dataset_id}: {e}") return None - + @classmethod - def list_datasets(cls, tenant_id: str, user_id: str, + def list_datasets(cls, tenant_id: str, user_id: str, page: int = 1, page_size: int = 20) -> Dict[str, Any]: """List datasets for a tenant""" try: @@ -104,10 +107,10 @@ class EvaluationService(CommonService): (EvaluationDataset.tenant_id == tenant_id) & (EvaluationDataset.status == StatusEnum.VALID.value) ).order_by(EvaluationDataset.create_time.desc()) - + total = query.count() datasets = query.paginate(page, page_size) - + return { "total": total, "datasets": [d.to_dict() for d in datasets] @@ -115,7 +118,7 @@ class EvaluationService(CommonService): except Exception as e: logging.error(f"Error listing datasets: {e}") return {"total": 0, "datasets": []} - + @classmethod def update_dataset(cls, dataset_id: str, **kwargs) -> bool: """Update dataset""" @@ -127,7 +130,7 @@ class EvaluationService(CommonService): except Exception as e: logging.error(f"Error updating dataset {dataset_id}: {e}") return False - + @classmethod def delete_dataset(cls, dataset_id: str) -> bool: """Soft delete dataset""" @@ -139,18 +142,18 @@ class EvaluationService(CommonService): except Exception as e: logging.error(f"Error deleting dataset {dataset_id}: {e}") return False - + # ==================== Test Case Management ==================== - + @classmethod - def add_test_case(cls, dataset_id: str, question: str, + def add_test_case(cls, dataset_id: str, question: str, reference_answer: Optional[str] = None, relevant_doc_ids: Optional[List[str]] = None, relevant_chunk_ids: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None) -> Tuple[bool, str]: """ Add a test case to a dataset. 
- + Args: dataset_id: Dataset ID question: Test question @@ -158,7 +161,7 @@ class EvaluationService(CommonService): relevant_doc_ids: Optional list of relevant document IDs relevant_chunk_ids: Optional list of relevant chunk IDs metadata: Optional additional metadata - + Returns: (success, case_id or error_message) """ @@ -174,15 +177,15 @@ class EvaluationService(CommonService): "metadata": metadata, "create_time": current_timestamp() } - + if not EvaluationCase.create(**case): return False, "Failed to create test case" - + return True, case_id except Exception as e: logging.error(f"Error adding test case: {e}") return False, str(e) - + @classmethod def get_test_cases(cls, dataset_id: str) -> List[Dict[str, Any]]: """Get all test cases for a dataset""" @@ -190,12 +193,12 @@ class EvaluationService(CommonService): cases = EvaluationCase.select().where( EvaluationCase.dataset_id == dataset_id ).order_by(EvaluationCase.create_time) - + return [c.to_dict() for c in cases] except Exception as e: logging.error(f"Error getting test cases for dataset {dataset_id}: {e}") return [] - + @classmethod def delete_test_case(cls, case_id: str) -> bool: """Delete a test case""" @@ -206,22 +209,22 @@ class EvaluationService(CommonService): except Exception as e: logging.error(f"Error deleting test case {case_id}: {e}") return False - + @classmethod def import_test_cases(cls, dataset_id: str, cases: List[Dict[str, Any]]) -> Tuple[int, int]: """ Bulk import test cases from a list. - + Args: dataset_id: Dataset ID cases: List of test case dictionaries - + Returns: (success_count, failure_count) """ success_count = 0 failure_count = 0 - + for case_data in cases: success, _ = cls.add_test_case( dataset_id=dataset_id, @@ -231,28 +234,28 @@ class EvaluationService(CommonService): relevant_chunk_ids=case_data.get("relevant_chunk_ids"), metadata=case_data.get("metadata") ) - + if success: success_count += 1 else: failure_count += 1 - + return success_count, failure_count - + # ==================== Evaluation Execution ==================== - + @classmethod - def start_evaluation(cls, dataset_id: str, dialog_id: str, + def start_evaluation(cls, dataset_id: str, dialog_id: str, user_id: str, name: Optional[str] = None) -> Tuple[bool, str]: """ Start an evaluation run. - + Args: dataset_id: Dataset ID dialog_id: Dialog configuration to evaluate user_id: User ID who starts the run name: Optional run name - + Returns: (success, run_id or error_message) """ @@ -261,12 +264,12 @@ class EvaluationService(CommonService): success, dialog = DialogService.get_by_id(dialog_id) if not success: return False, "Dialog not found" - + # Create evaluation run run_id = get_uuid() if not name: name = f"Evaluation Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - + run = { "id": run_id, "dataset_id": dataset_id, @@ -279,92 +282,128 @@ class EvaluationService(CommonService): "create_time": current_timestamp(), "complete_time": None } - + if not EvaluationRun.create(**run): return False, "Failed to create evaluation run" - + # Execute evaluation asynchronously (in production, use task queue) # For now, we'll execute synchronously cls._execute_evaluation(run_id, dataset_id, dialog) - + return True, run_id except Exception as e: logging.error(f"Error starting evaluation: {e}") return False, str(e) - + @classmethod def _execute_evaluation(cls, run_id: str, dataset_id: str, dialog: Any): """ Execute evaluation for all test cases. - + This method runs the RAG pipeline for each test case and computes metrics. 
""" try: # Get all test cases test_cases = cls.get_test_cases(dataset_id) - + if not test_cases: EvaluationRun.update( status="FAILED", complete_time=current_timestamp() ).where(EvaluationRun.id == run_id).execute() return - + # Execute each test case results = [] for case in test_cases: result = cls._evaluate_single_case(run_id, case, dialog) if result: results.append(result) - + # Compute summary metrics metrics_summary = cls._compute_summary_metrics(results) - + # Update run status EvaluationRun.update( status="COMPLETED", metrics_summary=metrics_summary, complete_time=current_timestamp() ).where(EvaluationRun.id == run_id).execute() - + except Exception as e: logging.error(f"Error executing evaluation {run_id}: {e}") EvaluationRun.update( status="FAILED", complete_time=current_timestamp() ).where(EvaluationRun.id == run_id).execute() - + @classmethod - def _evaluate_single_case(cls, run_id: str, case: Dict[str, Any], + def _evaluate_single_case(cls, run_id: str, case: Dict[str, Any], dialog: Any) -> Optional[Dict[str, Any]]: """ Evaluate a single test case. - + Args: run_id: Evaluation run ID case: Test case dictionary dialog: Dialog configuration - + Returns: Result dictionary or None if failed """ try: # Prepare messages messages = [{"role": "user", "content": case["question"]}] - + # Execute RAG pipeline start_time = timer() answer = "" retrieved_chunks = [] - + + + def _sync_from_async_gen(async_gen): + result_queue: queue.Queue = queue.Queue() + + def runner(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + async def consume(): + try: + async for item in async_gen: + result_queue.put(item) + except Exception as e: + result_queue.put(e) + finally: + result_queue.put(StopIteration) + + loop.run_until_complete(consume()) + loop.close() + + threading.Thread(target=runner, daemon=True).start() + + while True: + item = result_queue.get() + if item is StopIteration: + break + if isinstance(item, Exception): + raise item + yield item + + + def chat(dialog, messages, stream=True, **kwargs): + from api.db.services.dialog_service import async_chat + + return _sync_from_async_gen(async_chat(dialog, messages, stream=stream, **kwargs)) + for ans in chat(dialog, messages, stream=False): if isinstance(ans, dict): answer = ans.get("answer", "") retrieved_chunks = ans.get("reference", {}).get("chunks", []) break - + execution_time = timer() - start_time - + # Compute metrics metrics = cls._compute_metrics( question=case["question"], @@ -374,7 +413,7 @@ class EvaluationService(CommonService): relevant_chunk_ids=case.get("relevant_chunk_ids"), dialog=dialog ) - + # Save result result_id = get_uuid() result = { @@ -388,14 +427,14 @@ class EvaluationService(CommonService): "token_usage": None, # TODO: Track token usage "create_time": current_timestamp() } - + EvaluationResult.create(**result) - + return result except Exception as e: logging.error(f"Error evaluating case {case.get('id')}: {e}") return None - + @classmethod def _compute_metrics(cls, question: str, generated_answer: str, reference_answer: Optional[str], @@ -404,69 +443,69 @@ class EvaluationService(CommonService): dialog: Any) -> Dict[str, float]: """ Compute evaluation metrics for a single test case. 
- + Returns: Dictionary of metric names to values """ metrics = {} - + # Retrieval metrics (if ground truth chunks provided) if relevant_chunk_ids: retrieved_ids = [c.get("chunk_id") for c in retrieved_chunks] metrics.update(cls._compute_retrieval_metrics(retrieved_ids, relevant_chunk_ids)) - + # Generation metrics if generated_answer: # Basic metrics metrics["answer_length"] = len(generated_answer) metrics["has_answer"] = 1.0 if generated_answer.strip() else 0.0 - + # TODO: Implement advanced metrics using LLM-as-judge # - Faithfulness (hallucination detection) # - Answer relevance # - Context relevance # - Semantic similarity (if reference answer provided) - + return metrics - + @classmethod - def _compute_retrieval_metrics(cls, retrieved_ids: List[str], + def _compute_retrieval_metrics(cls, retrieved_ids: List[str], relevant_ids: List[str]) -> Dict[str, float]: """ Compute retrieval metrics. - + Args: retrieved_ids: List of retrieved chunk IDs relevant_ids: List of relevant chunk IDs (ground truth) - + Returns: Dictionary of retrieval metrics """ if not relevant_ids: return {} - + retrieved_set = set(retrieved_ids) relevant_set = set(relevant_ids) - + # Precision: proportion of retrieved that are relevant precision = len(retrieved_set & relevant_set) / len(retrieved_set) if retrieved_set else 0.0 - + # Recall: proportion of relevant that were retrieved recall = len(retrieved_set & relevant_set) / len(relevant_set) if relevant_set else 0.0 - + # F1 score f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0 - + # Hit rate: whether any relevant chunk was retrieved hit_rate = 1.0 if (retrieved_set & relevant_set) else 0.0 - + # MRR (Mean Reciprocal Rank): position of first relevant chunk mrr = 0.0 for i, chunk_id in enumerate(retrieved_ids, 1): if chunk_id in relevant_set: mrr = 1.0 / i break - + return { "precision": precision, "recall": recall, @@ -474,45 +513,45 @@ class EvaluationService(CommonService): "hit_rate": hit_rate, "mrr": mrr } - + @classmethod def _compute_summary_metrics(cls, results: List[Dict[str, Any]]) -> Dict[str, Any]: """ Compute summary metrics across all test cases. 
- + Args: results: List of result dictionaries - + Returns: Summary metrics dictionary """ if not results: return {} - + # Aggregate metrics metric_sums = {} metric_counts = {} - + for result in results: metrics = result.get("metrics", {}) for key, value in metrics.items(): if isinstance(value, (int, float)): metric_sums[key] = metric_sums.get(key, 0) + value metric_counts[key] = metric_counts.get(key, 0) + 1 - + # Compute averages summary = { "total_cases": len(results), "avg_execution_time": sum(r.get("execution_time", 0) for r in results) / len(results) } - + for key in metric_sums: summary[f"avg_{key}"] = metric_sums[key] / metric_counts[key] - + return summary - + # ==================== Results & Analysis ==================== - + @classmethod def get_run_results(cls, run_id: str) -> Dict[str, Any]: """Get results for an evaluation run""" @@ -520,11 +559,11 @@ class EvaluationService(CommonService): run = EvaluationRun.get_by_id(run_id) if not run: return {} - + results = EvaluationResult.select().where( EvaluationResult.run_id == run_id ).order_by(EvaluationResult.create_time) - + return { "run": run.to_dict(), "results": [r.to_dict() for r in results] @@ -532,15 +571,15 @@ class EvaluationService(CommonService): except Exception as e: logging.error(f"Error getting run results {run_id}: {e}") return {} - + @classmethod def get_recommendations(cls, run_id: str) -> List[Dict[str, Any]]: """ Analyze evaluation results and provide configuration recommendations. - + Args: run_id: Evaluation run ID - + Returns: List of recommendation dictionaries """ @@ -548,10 +587,10 @@ class EvaluationService(CommonService): run = EvaluationRun.get_by_id(run_id) if not run or not run.metrics_summary: return [] - + metrics = run.metrics_summary recommendations = [] - + # Low precision: retrieving irrelevant chunks if metrics.get("avg_precision", 1.0) < 0.7: recommendations.append({ @@ -564,7 +603,7 @@ class EvaluationService(CommonService): "Reduce top_k to return fewer chunks" ] }) - + # Low recall: missing relevant chunks if metrics.get("avg_recall", 1.0) < 0.7: recommendations.append({ @@ -578,7 +617,7 @@ class EvaluationService(CommonService): "Check chunk size - may be too large or too small" ] }) - + # Slow response time if metrics.get("avg_execution_time", 0) > 5.0: recommendations.append({ @@ -591,7 +630,7 @@ class EvaluationService(CommonService): "Consider caching frequently asked questions" ] }) - + return recommendations except Exception as e: logging.error(f"Error generating recommendations for run {run_id}: {e}") diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 6a63713ec..86356a7a7 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -16,15 +16,17 @@ import asyncio import inspect import logging +import queue import re import threading -from common.token_utils import num_tokens_from_string from functools import partial from typing import Generator -from common.constants import LLMType + from api.db.db_models import LLM from api.db.services.common_service import CommonService from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService +from common.constants import LLMType +from common.token_utils import num_tokens_from_string class LLMService(CommonService): @@ -33,6 +35,7 @@ class LLMService(CommonService): def get_init_tenant_llm(user_id): from common import settings + tenant_llm = [] model_configs = { @@ -193,7 +196,7 @@ class LLMBundle(LLM4Tenant): generation = self.langfuse.start_generation( 
trace_context=self.trace_context, name="stream_transcription", - metadata={"model": self.llm_name} + metadata={"model": self.llm_name}, ) final_text = "" used_tokens = 0 @@ -217,32 +220,34 @@ class LLMBundle(LLM4Tenant): if self.langfuse: generation.update( output={"output": final_text}, - usage_details={"total_tokens": used_tokens} + usage_details={"total_tokens": used_tokens}, ) generation.end() return if self.langfuse: - generation = self.langfuse.start_generation(trace_context=self.trace_context, name="stream_transcription", metadata={"model": self.llm_name}) - full_text, used_tokens = mdl.transcription(audio) - if not TenantLLMService.increase_usage( - self.tenant_id, self.llm_type, used_tokens - ): - logging.error( - f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}" + generation = self.langfuse.start_generation( + trace_context=self.trace_context, + name="stream_transcription", + metadata={"model": self.llm_name}, ) + + full_text, used_tokens = mdl.transcription(audio) + if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens): + logging.error(f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}") + if self.langfuse: generation.update( output={"output": full_text}, - usage_details={"total_tokens": used_tokens} + usage_details={"total_tokens": used_tokens}, ) generation.end() yield { "event": "final", "text": full_text, - "streaming": False + "streaming": False, } def tts(self, text: str) -> Generator[bytes, None, None]: @@ -289,61 +294,79 @@ class LLMBundle(LLM4Tenant): return kwargs else: return {k: v for k, v in kwargs.items() if k in allowed_params} + + def _run_coroutine_sync(self, coro): + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro) + + result_queue: queue.Queue = queue.Queue() + + def runner(): + try: + result_queue.put((True, asyncio.run(coro))) + except Exception as e: + result_queue.put((False, e)) + + thread = threading.Thread(target=runner, daemon=True) + thread.start() + thread.join() + + success, value = result_queue.get_nowait() + if success: + return value + raise value + def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str: - if self.langfuse: - generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history}) + return self._run_coroutine_sync(self.async_chat(system, history, gen_conf, **kwargs)) - chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs) - if self.is_tools and self.mdl.is_tools: - chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs) + def _sync_from_async_stream(self, async_gen_fn, *args, **kwargs): + result_queue: queue.Queue = queue.Queue() - use_kwargs = self._clean_param(chat_partial, **kwargs) - txt, used_tokens = chat_partial(**use_kwargs) - txt = self._remove_reasoning_content(txt) + def runner(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) - if not self.verbose_tool_use: - txt = re.sub(r".*?", "", txt, flags=re.DOTALL) + async def consume(): + try: + async for item in async_gen_fn(*args, **kwargs): + result_queue.put(item) + except Exception as e: + result_queue.put(e) + finally: + result_queue.put(StopIteration) - if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name): - 
logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens)) + loop.run_until_complete(consume()) + loop.close() - if self.langfuse: - generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens}) - generation.end() + threading.Thread(target=runner, daemon=True).start() - return txt + while True: + item = result_queue.get() + if item is StopIteration: + break + if isinstance(item, Exception): + raise item + yield item def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs): - if self.langfuse: - generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history}) - ans = "" - chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf) - total_tokens = 0 - if self.is_tools and self.mdl.is_tools: - chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf) - use_kwargs = self._clean_param(chat_partial, **kwargs) - for txt in chat_partial(**use_kwargs): + for txt in self._sync_from_async_stream(self.async_chat_streamly, system, history, gen_conf, **kwargs): if isinstance(txt, int): - total_tokens = txt - if self.langfuse: - generation.update(output={"output": ans}) - generation.end() break if txt.endswith(""): - ans = ans[: -len("")] + ans = txt[: -len("")] + continue if not self.verbose_tool_use: txt = re.sub(r".*?", "", txt, flags=re.DOTALL) - ans += txt + # cancatination has beend done in async_chat_streamly + ans = txt yield ans - if total_tokens > 0: - if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name): - logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens)) - def _bridge_sync_stream(self, gen): loop = asyncio.get_running_loop() queue: asyncio.Queue = asyncio.Queue() @@ -352,7 +375,7 @@ class LLMBundle(LLM4Tenant): try: for item in gen: loop.call_soon_threadsafe(queue.put_nowait, item) - except Exception as e: # pragma: no cover + except Exception as e: loop.call_soon_threadsafe(queue.put_nowait, e) finally: loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration) @@ -361,18 +384,27 @@ class LLMBundle(LLM4Tenant): return queue async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs): - chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs) - if self.is_tools and self.mdl.is_tools and hasattr(self.mdl, "chat_with_tools"): - chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs) + if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_with_tools"): + base_fn = self.mdl.async_chat_with_tools + elif hasattr(self.mdl, "async_chat"): + base_fn = self.mdl.async_chat + else: + raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools") + generation = None + if self.langfuse: + generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history}) + + chat_partial = partial(base_fn, system, history, gen_conf) use_kwargs = self._clean_param(chat_partial, **kwargs) - if hasattr(self.mdl, "async_chat_with_tools") and self.is_tools and self.mdl.is_tools: - txt, used_tokens = await self.mdl.async_chat_with_tools(system, history, 
gen_conf, **use_kwargs) - elif hasattr(self.mdl, "async_chat"): - txt, used_tokens = await self.mdl.async_chat(system, history, gen_conf, **use_kwargs) - else: - txt, used_tokens = await asyncio.to_thread(chat_partial, **use_kwargs) + try: + txt, used_tokens = await chat_partial(**use_kwargs) + except Exception as e: + if generation: + generation.update(output={"error": str(e)}) + generation.end() + raise txt = self._remove_reasoning_content(txt) if not self.verbose_tool_use: @@ -381,49 +413,51 @@ class LLMBundle(LLM4Tenant): if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name): logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens)) + if generation: + generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens}) + generation.end() + return txt async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs): total_tokens = 0 ans = "" - if self.is_tools and self.mdl.is_tools: + if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"): stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None) - else: + elif hasattr(self.mdl, "async_chat_streamly"): stream_fn = getattr(self.mdl, "async_chat_streamly", None) + else: + raise RuntimeError(f"Model {self.mdl} does not implement async_chat_streamly or async_chat_streamly_with_tools") + + generation = None + if self.langfuse: + generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history}) if stream_fn: chat_partial = partial(stream_fn, system, history, gen_conf) use_kwargs = self._clean_param(chat_partial, **kwargs) - async for txt in chat_partial(**use_kwargs): - if isinstance(txt, int): - total_tokens = txt - break + try: + async for txt in chat_partial(**use_kwargs): + if isinstance(txt, int): + total_tokens = txt + break - if txt.endswith(""): - ans = ans[: -len("")] + if txt.endswith(""): + ans = ans[: -len("")] - if not self.verbose_tool_use: - txt = re.sub(r".*?", "", txt, flags=re.DOTALL) + if not self.verbose_tool_use: + txt = re.sub(r".*?", "", txt, flags=re.DOTALL) - ans += txt - yield ans + ans += txt + yield ans + except Exception as e: + if generation: + generation.update(output={"error": str(e)}) + generation.end() + raise if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name): logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens)) + if generation: + generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens}) + generation.end() return - - chat_partial = partial(self.mdl.chat_streamly_with_tools if (self.is_tools and self.mdl.is_tools) else self.mdl.chat_streamly, system, history, gen_conf) - use_kwargs = self._clean_param(chat_partial, **kwargs) - queue = self._bridge_sync_stream(chat_partial(**use_kwargs)) - while True: - item = await queue.get() - if item is StopAsyncIteration: - break - if isinstance(item, Exception): - raise item - if isinstance(item, int): - total_tokens = item - break - yield item - - if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name): - logging.error("LLMBundle.async_chat_streamly
can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens)) diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index 3ff5311fc..67bf0bb09 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -52,6 +52,8 @@ class SupportedLiteLLMProvider(StrEnum): JiekouAI = "Jiekou.AI" ZHIPU_AI = "ZHIPU-AI" MiniMax = "MiniMax" + DeerAPI = "DeerAPI" + GPUStack = "GPUStack" FACTORY_DEFAULT_BASE_URL = { @@ -75,6 +77,7 @@ FACTORY_DEFAULT_BASE_URL = { SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai", SupportedLiteLLMProvider.ZHIPU_AI: "https://open.bigmodel.cn/api/paas/v4", SupportedLiteLLMProvider.MiniMax: "https://api.minimaxi.com/v1", + SupportedLiteLLMProvider.DeerAPI: "https://api.deerapi.com/v1", } @@ -108,6 +111,8 @@ LITELLM_PROVIDER_PREFIX = { SupportedLiteLLMProvider.JiekouAI: "openai/", SupportedLiteLLMProvider.ZHIPU_AI: "openai/", SupportedLiteLLMProvider.MiniMax: "openai/", + SupportedLiteLLMProvider.DeerAPI: "openai/", + SupportedLiteLLMProvider.GPUStack: "openai/", } ChatModel = globals().get("ChatModel", {}) diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index e69ff1868..9f5457224 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -19,7 +19,6 @@ import logging import os import random import re -import threading import time from abc import ABC from copy import deepcopy @@ -78,11 +77,9 @@ class Base(ABC): self.toolcall_sessions = {} def _get_delay(self): - """Calculate retry delay time""" return self.base_delay * random.uniform(10, 150) def _classify_error(self, error): - """Classify error based on error message content""" error_str = str(error).lower() keywords_mapping = [ @@ -139,89 +136,7 @@ class Base(ABC): return gen_conf - def _bridge_sync_stream(self, gen): - """Run a sync generator in a thread and yield asynchronously.""" - loop = asyncio.get_running_loop() - queue: asyncio.Queue = asyncio.Queue() - - def worker(): - try: - for item in gen: - loop.call_soon_threadsafe(queue.put_nowait, item) - except Exception as exc: # pragma: no cover - defensive - loop.call_soon_threadsafe(queue.put_nowait, exc) - finally: - loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration) - - threading.Thread(target=worker, daemon=True).start() - return queue - - def _chat(self, history, gen_conf, **kwargs): - logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2)) - if self.model_name.lower().find("qwq") >= 0: - logging.info(f"[INFO] {self.model_name} detected as reasoning model, using _chat_streamly") - - final_ans = "" - tol_token = 0 - for delta, tol in self._chat_streamly(history, gen_conf, with_reasoning=False, **kwargs): - if delta.startswith("") or delta.endswith(""): - continue - final_ans += delta - tol_token = tol - - if len(final_ans.strip()) == 0: - final_ans = "**ERROR**: Empty response from reasoning model" - - return final_ans.strip(), tol_token - - if self.model_name.lower().find("qwen3") >= 0: - kwargs["extra_body"] = {"enable_thinking": False} - - response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs) - - if not response.choices or not response.choices[0].message or not response.choices[0].message.content: - return "", 0 - ans = response.choices[0].message.content.strip() - if response.choices[0].finish_reason == "length": - ans = self._length_stop(ans) - return ans, total_token_count_from_response(response) - - def _chat_streamly(self, history, gen_conf, **kwargs): - logging.info("[HISTORY 
STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) - reasoning_start = False - - if kwargs.get("stop") or "stop" in gen_conf: - response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop")) - else: - response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf) - - for resp in response: - if not resp.choices: - continue - if not resp.choices[0].delta.content: - resp.choices[0].delta.content = "" - if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content: - ans = "" - if not reasoning_start: - reasoning_start = True - ans = "" - ans += resp.choices[0].delta.reasoning_content + "" - else: - reasoning_start = False - ans = resp.choices[0].delta.content - - tol = total_token_count_from_response(resp) - if not tol: - tol = num_tokens_from_string(resp.choices[0].delta.content) - - if resp.choices[0].finish_reason == "length": - if is_chinese(ans): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - yield ans, tol - - async def _async_chat_stream(self, history, gen_conf, **kwargs): + async def _async_chat_streamly(self, history, gen_conf, **kwargs): logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) reasoning_start = False @@ -265,13 +180,19 @@ class Base(ABC): gen_conf = self._clean_conf(gen_conf) ans = "" total_tokens = 0 - try: - async for delta_ans, tol in self._async_chat_stream(history, gen_conf, **kwargs): - ans = delta_ans - total_tokens += tol - yield delta_ans - except openai.APIError as e: - yield ans + "\n**ERROR**: " + str(e) + + for attempt in range(self.max_retries + 1): + try: + async for delta_ans, tol in self._async_chat_streamly(history, gen_conf, **kwargs): + ans = delta_ans + total_tokens += tol + yield ans + except Exception as e: + e = await self._exceptions_async(e, attempt) + if e: + yield e + yield total_tokens + return yield total_tokens @@ -307,7 +228,7 @@ class Base(ABC): logging.error(f"sync base giving up: {msg}") return msg - async def _exceptions_async(self, e, attempt) -> str | None: + async def _exceptions_async(self, e, attempt): logging.exception("OpenAI async completion") error_code = self._classify_error(e) if attempt == self.max_retries: @@ -357,61 +278,6 @@ class Base(ABC): self.toolcall_session = toolcall_session self.tools = tools - def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}): - gen_conf = self._clean_conf(gen_conf) - if system and history and history[0].get("role") != "system": - history.insert(0, {"role": "system", "content": system}) - - ans = "" - tk_count = 0 - hist = deepcopy(history) - # Implement exponential backoff retry strategy - for attempt in range(self.max_retries + 1): - history = hist - try: - for _ in range(self.max_rounds + 1): - logging.info(f"{self.tools=}") - response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf) - tk_count += total_token_count_from_response(response) - if any([not response.choices, not response.choices[0].message]): - raise Exception(f"500 response structure error. 
Response: {response}") - - if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls: - if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content: - ans += "" + response.choices[0].message.reasoning_content + "" - - ans += response.choices[0].message.content - if response.choices[0].finish_reason == "length": - ans = self._length_stop(ans) - - return ans, tk_count - - for tool_call in response.choices[0].message.tool_calls: - logging.info(f"Response {tool_call=}") - name = tool_call.function.name - try: - args = json_repair.loads(tool_call.function.arguments) - tool_response = self.toolcall_session.tool_call(name, args) - history = self._append_history(history, tool_call, tool_response) - ans += self._verbose_tool_use(name, args, tool_response) - except Exception as e: - logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") - history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) - ans += self._verbose_tool_use(name, {}, str(e)) - - logging.warning(f"Exceed max rounds: {self.max_rounds}") - history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) - response, token_count = self._chat(history, gen_conf) - ans += response - tk_count += token_count - return ans, tk_count - except Exception as e: - e = self._exceptions(e, attempt) - if e: - return e, tk_count - - assert False, "Shouldn't be here." - async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}): gen_conf = self._clean_conf(gen_conf) if system and history and history[0].get("role") != "system": @@ -466,140 +332,6 @@ class Base(ABC): assert False, "Shouldn't be here." - def chat(self, system, history, gen_conf={}, **kwargs): - if system and history and history[0].get("role") != "system": - history.insert(0, {"role": "system", "content": system}) - gen_conf = self._clean_conf(gen_conf) - - # Implement exponential backoff retry strategy - for attempt in range(self.max_retries + 1): - try: - return self._chat(history, gen_conf, **kwargs) - except Exception as e: - e = self._exceptions(e, attempt) - if e: - return e, 0 - assert False, "Shouldn't be here." 
-
- def _wrap_toolcall_message(self, stream):
- final_tool_calls = {}
-
- for chunk in stream:
- for tool_call in chunk.choices[0].delta.tool_calls or []:
- index = tool_call.index
-
- if index not in final_tool_calls:
- final_tool_calls[index] = tool_call
-
- final_tool_calls[index].function.arguments += tool_call.function.arguments
-
- return final_tool_calls
-
- def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
- gen_conf = self._clean_conf(gen_conf)
- tools = self.tools
- if system and history and history[0].get("role") != "system":
- history.insert(0, {"role": "system", "content": system})
-
- total_tokens = 0
- hist = deepcopy(history)
- # Implement exponential backoff retry strategy
- for attempt in range(self.max_retries + 1):
- history = hist
- try:
- for _ in range(self.max_rounds + 1):
- reasoning_start = False
- logging.info(f"{tools=}")
- response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
- final_tool_calls = {}
- answer = ""
- for resp in response:
- if resp.choices[0].delta.tool_calls:
- for tool_call in resp.choices[0].delta.tool_calls or []:
- index = tool_call.index
-
- if index not in final_tool_calls:
- if not tool_call.function.arguments:
- tool_call.function.arguments = ""
- final_tool_calls[index] = tool_call
- else:
- final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""
- continue
-
- if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
- raise Exception("500 response structure error.")
-
- if not resp.choices[0].delta.content:
- resp.choices[0].delta.content = ""
-
- if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
- ans = ""
- if not reasoning_start:
- reasoning_start = True
- ans = "<think>"
- ans += resp.choices[0].delta.reasoning_content + "</think>"
- yield ans
- else:
- reasoning_start = False
- answer += resp.choices[0].delta.content
- yield resp.choices[0].delta.content
-
- tol = total_token_count_from_response(resp)
- if not tol:
- total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
- else:
- total_tokens = tol
-
- finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
- if finish_reason == "length":
- yield self._length_stop("")
-
- if answer:
- yield total_tokens
- return
-
- for tool_call in final_tool_calls.values():
- name = tool_call.function.name
- try:
- args = json_repair.loads(tool_call.function.arguments)
- yield self._verbose_tool_use(name, args, "Begin to call...")
- tool_response = self.toolcall_session.tool_call(name, args)
- history = self._append_history(history, tool_call, tool_response)
- yield self._verbose_tool_use(name, args, tool_response)
- except Exception as e:
- logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
- history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
- yield self._verbose_tool_use(name, {}, str(e))
-
- logging.warning(f"Exceed max rounds: {self.max_rounds}")
- history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
- response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
- for resp in response:
- if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
- raise Exception("500 response structure error.")
- if not resp.choices[0].delta.content:
- resp.choices[0].delta.content = ""
- continue
- tol = total_token_count_from_response(resp)
- if not tol:
- total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
- else:
- total_tokens = tol
- answer += resp.choices[0].delta.content
- yield resp.choices[0].delta.content
-
- yield total_tokens
- return
-
- except Exception as e:
- e = self._exceptions(e, attempt)
- if e:
- yield e
- yield total_tokens
- return
-
- assert False, "Shouldn't be here."
-
 async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
 gen_conf = self._clean_conf(gen_conf)
 tools = self.tools
@@ -715,9 +447,10 @@ class Base(ABC):
 logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
 if self.model_name.lower().find("qwq") >= 0:
 logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly")
+
 final_ans = ""
 tol_token = 0
- async for delta, tol in self._async_chat_stream(history, gen_conf, with_reasoning=False, **kwargs):
+ async for delta, tol in self._async_chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
 if delta.startswith("<think>") or delta.endswith("</think>"):
 continue
 final_ans += delta
@@ -754,57 +487,6 @@ class Base(ABC):
 return e, 0
 assert False, "Shouldn't be here."
-
- def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
- if system and history and history[0].get("role") != "system":
- history.insert(0, {"role": "system", "content": system})
- gen_conf = self._clean_conf(gen_conf)
- ans = ""
- total_tokens = 0
- try:
- for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
- yield delta_ans
- total_tokens += tol
- except openai.APIError as e:
- yield ans + "\n**ERROR**: " + str(e)
-
- yield total_tokens
-
- def _calculate_dynamic_ctx(self, history):
- """Calculate dynamic context window size"""
-
- def count_tokens(text):
- """Calculate token count for text"""
- # Simple calculation: 1 token per ASCII character
- # 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
- total = 0
- for char in text:
- if ord(char) < 128: # ASCII characters
- total += 1
- else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
- total += 2 - return total - - # Calculate total tokens for all messages - total_tokens = 0 - for message in history: - content = message.get("content", "") - # Calculate content tokens - content_tokens = count_tokens(content) - # Add role marker token overhead - role_tokens = 4 - total_tokens += content_tokens + role_tokens - - # Apply 1.2x buffer ratio - total_tokens_with_buffer = int(total_tokens * 1.2) - - if total_tokens_with_buffer <= 8192: - ctx_size = 8192 - else: - ctx_multiplier = (total_tokens_with_buffer // 8192) + 1 - ctx_size = ctx_multiplier * 8192 - - return ctx_size - class GptTurbo(Base): _FACTORY_NAME = "OpenAI" @@ -1504,16 +1186,6 @@ class GoogleChat(Base): yield total_tokens -class GPUStackChat(Base): - _FACTORY_NAME = "GPUStack" - - def __init__(self, key=None, model_name="", base_url="", **kwargs): - if not base_url: - raise ValueError("Local llm url cannot be None") - base_url = urljoin(base_url, "v1") - super().__init__(key, model_name, base_url, **kwargs) - - class TokenPonyChat(Base): _FACTORY_NAME = "TokenPony" @@ -1523,15 +1195,6 @@ class TokenPonyChat(Base): super().__init__(key, model_name, base_url, **kwargs) -class DeerAPIChat(Base): - _FACTORY_NAME = "DeerAPI" - - def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1", **kwargs): - if not base_url: - base_url = "https://api.deerapi.com/v1" - super().__init__(key, model_name, base_url, **kwargs) - - class LiteLLMBase(ABC): _FACTORY_NAME = [ "Tongyi-Qianwen", @@ -1562,6 +1225,8 @@ class LiteLLMBase(ABC): "Jiekou.AI", "ZHIPU-AI", "MiniMax", + "DeerAPI", + "GPUStack", ] def __init__(self, key, model_name, base_url=None, **kwargs): @@ -1589,11 +1254,9 @@ class LiteLLMBase(ABC): self.provider_order = json.loads(key).get("provider_order", "") def _get_delay(self): - """Calculate retry delay time""" return self.base_delay * random.uniform(10, 150) def _classify_error(self, error): - """Classify error based on error message content""" error_str = str(error).lower() keywords_mapping = [ @@ -1619,72 +1282,6 @@ class LiteLLMBase(ABC): del gen_conf["max_tokens"] return gen_conf - def _chat(self, history, gen_conf, **kwargs): - logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2)) - if self.model_name.lower().find("qwen3") >= 0: - kwargs["extra_body"] = {"enable_thinking": False} - - completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf) - response = litellm.completion( - **completion_args, - drop_params=True, - timeout=self.timeout, - ) - # response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs) - if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]): - return "", 0 - ans = response.choices[0].message.content.strip() - if response.choices[0].finish_reason == "length": - ans = self._length_stop(ans) - - return ans, total_token_count_from_response(response) - - def _chat_streamly(self, history, gen_conf, **kwargs): - logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) - gen_conf = self._clean_conf(gen_conf) - reasoning_start = False - - completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf) - stop = kwargs.get("stop") - if stop: - completion_args["stop"] = stop - response = litellm.completion( - **completion_args, - drop_params=True, - timeout=self.timeout, - ) - - for resp in response: - if not hasattr(resp, "choices") or not resp.choices: 
- continue
-
- delta = resp.choices[0].delta
- if not hasattr(delta, "content") or delta.content is None:
- delta.content = ""
-
- if kwargs.get("with_reasoning", True) and hasattr(delta, "reasoning_content") and delta.reasoning_content:
- ans = ""
- if not reasoning_start:
- reasoning_start = True
- ans = "<think>"
- ans += delta.reasoning_content + "</think>"
- else:
- reasoning_start = False
- ans = delta.content
-
- tol = total_token_count_from_response(resp)
- if not tol:
- tol = num_tokens_from_string(delta.content)
-
- finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
- if finish_reason == "length":
- if is_chinese(ans):
- ans += LENGTH_NOTIFICATION_CN
- else:
- ans += LENGTH_NOTIFICATION_EN
-
- yield ans, tol
-
 async def async_chat(self, system, history, gen_conf, **kwargs):
 hist = list(history) if history else []
 if system:
@@ -1795,22 +1392,7 @@ class LiteLLMBase(ABC):
 def _should_retry(self, error_code: str) -> bool:
 return error_code in self._retryable_errors

- def _exceptions(self, e, attempt) -> str | None:
- logging.exception("OpenAI chat_with_tools")
- # Classify the error
- error_code = self._classify_error(e)
- if attempt == self.max_retries:
- error_code = LLMErrorCode.ERROR_MAX_RETRIES
-
- if self._should_retry(error_code):
- delay = self._get_delay()
- logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
- time.sleep(delay)
- return None
-
- return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
-
- async def _exceptions_async(self, e, attempt) -> str | None:
+ async def _exceptions_async(self, e, attempt):
 logging.exception("LiteLLMBase async completion")
 error_code = self._classify_error(e)
 if attempt == self.max_retries:
@@ -1859,71 +1441,7 @@ class LiteLLMBase(ABC):
 self.toolcall_session = toolcall_session
 self.tools = tools

- def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs):
- completion_args = {
- "model": self.model_name,
- "messages": history,
- "api_key": self.api_key,
- "num_retries": self.max_retries,
- **kwargs,
- }
- if stream:
- completion_args.update(
- {
- "stream": stream,
- }
- )
- if tools and self.tools:
- completion_args.update(
- {
- "tools": self.tools,
- "tool_choice": "auto",
- }
- )
- if self.provider in FACTORY_DEFAULT_BASE_URL:
- completion_args.update({"api_base": self.base_url})
- elif self.provider == SupportedLiteLLMProvider.Bedrock:
- completion_args.pop("api_key", None)
- completion_args.pop("api_base", None)
- completion_args.update(
- {
- "aws_access_key_id": self.bedrock_ak,
- "aws_secret_access_key": self.bedrock_sk,
- "aws_region_name": self.bedrock_region,
- }
- )
-
- if self.provider == SupportedLiteLLMProvider.OpenRouter:
- if self.provider_order:
-
- def _to_order_list(x):
- if x is None:
- return []
- if isinstance(x, str):
- return [s.strip() for s in x.split(",") if s.strip()]
- if isinstance(x, (list, tuple)):
- return [str(s).strip() for s in x if str(s).strip()]
- return []
-
- extra_body = {}
- provider_cfg = {}
- provider_order = _to_order_list(self.provider_order)
- provider_cfg["order"] = provider_order
- provider_cfg["allow_fallbacks"] = False
- extra_body["provider"] = provider_cfg
- completion_args.update({"extra_body": extra_body})
-
- # Ollama deployments commonly sit behind a reverse proxy that enforces
- # Bearer auth. Ensure the Authorization header is set when an API key
- # is provided, while respecting any user-supplied headers. 
#11350 - extra_headers = deepcopy(completion_args.get("extra_headers") or {}) - if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers: - extra_headers["Authorization"] = f"Bearer {self.api_key}" - if extra_headers: - completion_args["extra_headers"] = extra_headers - return completion_args - - def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}): + async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}): gen_conf = self._clean_conf(gen_conf) if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) @@ -1931,16 +1449,14 @@ class LiteLLMBase(ABC): ans = "" tk_count = 0 hist = deepcopy(history) - - # Implement exponential backoff retry strategy for attempt in range(self.max_retries + 1): - history = deepcopy(hist) # deepcopy is required here + history = deepcopy(hist) try: for _ in range(self.max_rounds + 1): logging.info(f"{self.tools=}") completion_args = self._construct_completion_args(history=history, stream=False, tools=True, **gen_conf) - response = litellm.completion( + response = await litellm.acompletion( **completion_args, drop_params=True, timeout=self.timeout, @@ -1966,7 +1482,7 @@ class LiteLLMBase(ABC): name = tool_call.function.name try: args = json_repair.loads(tool_call.function.arguments) - tool_response = self.toolcall_session.tool_call(name, args) + tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args) history = self._append_history(history, tool_call, tool_response) ans += self._verbose_tool_use(name, args, tool_response) except Exception as e: @@ -1977,49 +1493,19 @@ class LiteLLMBase(ABC): logging.warning(f"Exceed max rounds: {self.max_rounds}") history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) - response, token_count = self._chat(history, gen_conf) + response, token_count = await self.async_chat("", history, gen_conf) ans += response tk_count += token_count return ans, tk_count except Exception as e: - e = self._exceptions(e, attempt) + e = await self._exceptions_async(e, attempt) if e: return e, tk_count assert False, "Shouldn't be here." - def chat(self, system, history, gen_conf={}, **kwargs): - if system and history and history[0].get("role") != "system": - history.insert(0, {"role": "system", "content": system}) - gen_conf = self._clean_conf(gen_conf) - - # Implement exponential backoff retry strategy - for attempt in range(self.max_retries + 1): - try: - response = self._chat(history, gen_conf, **kwargs) - return response - except Exception as e: - e = self._exceptions(e, attempt) - if e: - return e, 0 - assert False, "Shouldn't be here." 
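
The async tool-calling path above awaits litellm.acompletion and pushes the still-synchronous toolcall_session.tool_call into a worker thread via asyncio.to_thread, so a slow tool cannot block the event loop. A standalone sketch of that offloading pattern (illustrative only; SlowToolSession is a made-up stand-in, not a class from this codebase):

import asyncio
import time


class SlowToolSession:
    def tool_call(self, name, args):
        time.sleep(1)  # stands in for blocking work such as an HTTP or DB call
        return f"{name}({args}) finished"


async def main():
    session = SlowToolSession()
    # Same call shape as in the patch: run the sync tool_call off the event loop.
    result = await asyncio.to_thread(session.tool_call, "search", {"q": "async"})
    print(result)


asyncio.run(main())
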
- - def _wrap_toolcall_message(self, stream): - final_tool_calls = {} - - for chunk in stream: - for tool_call in chunk.choices[0].delta.tool_calls or []: - index = tool_call.index - - if index not in final_tool_calls: - final_tool_calls[index] = tool_call - - final_tool_calls[index].function.arguments += tool_call.function.arguments - - return final_tool_calls - - def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}): + async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}): gen_conf = self._clean_conf(gen_conf) tools = self.tools if system and history and history[0].get("role") != "system": @@ -2028,16 +1514,15 @@ class LiteLLMBase(ABC): total_tokens = 0 hist = deepcopy(history) - # Implement exponential backoff retry strategy for attempt in range(self.max_retries + 1): - history = deepcopy(hist) # deepcopy is required here + history = deepcopy(hist) try: for _ in range(self.max_rounds + 1): reasoning_start = False logging.info(f"{tools=}") completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf) - response = litellm.completion( + response = await litellm.acompletion( **completion_args, drop_params=True, timeout=self.timeout, @@ -2046,7 +1531,7 @@ class LiteLLMBase(ABC): final_tool_calls = {} answer = "" - for resp in response: + async for resp in response: if not hasattr(resp, "choices") or not resp.choices: continue @@ -2082,7 +1567,7 @@ class LiteLLMBase(ABC): if not tol: total_tokens += num_tokens_from_string(delta.content) else: - total_tokens += tol + total_tokens = tol finish_reason = getattr(resp.choices[0], "finish_reason", "") if finish_reason == "length": @@ -2097,31 +1582,25 @@ class LiteLLMBase(ABC): try: args = json_repair.loads(tool_call.function.arguments) yield self._verbose_tool_use(name, args, "Begin to call...") - tool_response = self.toolcall_session.tool_call(name, args) + tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args) history = self._append_history(history, tool_call, tool_response) yield self._verbose_tool_use(name, args, tool_response) except Exception as e: logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") - history.append( - { - "role": "tool", - "tool_call_id": tool_call.id, - "content": f"Tool call error: \n{tool_call}\nException:\n{str(e)}", - } - ) + history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) yield self._verbose_tool_use(name, {}, str(e)) logging.warning(f"Exceed max rounds: {self.max_rounds}") history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf) - response = litellm.completion( + response = await litellm.acompletion( **completion_args, drop_params=True, timeout=self.timeout, ) - for resp in response: + async for resp in response: if not hasattr(resp, "choices") or not resp.choices: continue delta = resp.choices[0].delta @@ -2131,14 +1610,14 @@ class LiteLLMBase(ABC): if not tol: total_tokens += num_tokens_from_string(delta.content) else: - total_tokens += tol + total_tokens = tol yield delta.content yield total_tokens return except Exception as e: - e = self._exceptions(e, attempt) + e = await self._exceptions_async(e, attempt) if e: yield e yield total_tokens @@ -2146,53 +1625,71 @@ class LiteLLMBase(ABC): assert False, "Shouldn't be here." 
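
The streaming loop above accumulates tool-call fragments keyed by their index, appending function.arguments until the stream completes, and now assigns (rather than adds) the reported usage figure, presumably because the provider returns it as a running total. A standalone sketch of the per-index merge (illustrative only; Fragment and Fn are hypothetical stand-ins for the SDK's delta.tool_calls entries):

from dataclasses import dataclass, field


@dataclass
class Fn:
    name: str = ""
    arguments: str = ""


@dataclass
class Fragment:
    index: int
    function: Fn = field(default_factory=Fn)


def merge(fragments):
    # Mirrors final_tool_calls: first fragment per index wins, later ones append arguments.
    final = {}
    for tc in fragments:
        if tc.index not in final:
            tc.function.arguments = tc.function.arguments or ""
            final[tc.index] = tc
        else:
            final[tc.index].function.arguments += tc.function.arguments or ""
    return final


chunks = [Fragment(0, Fn("search", '{"q":')), Fragment(0, Fn("", '"async"}'))]
print(merge(chunks)[0].function.arguments)  # {"q":"async"}
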
- def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs): - if system and history and history[0].get("role") != "system": - history.insert(0, {"role": "system", "content": system}) - gen_conf = self._clean_conf(gen_conf) - ans = "" - total_tokens = 0 - try: - for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs): - yield delta_ans - total_tokens += tol - except openai.APIError as e: - yield ans + "\n**ERROR**: " + str(e) + def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs): + completion_args = { + "model": self.model_name, + "messages": history, + "api_key": self.api_key, + "num_retries": self.max_retries, + **kwargs, + } + if stream: + completion_args.update( + { + "stream": stream, + } + ) + if tools and self.tools: + completion_args.update( + { + "tools": self.tools, + "tool_choice": "auto", + } + ) + if self.provider in FACTORY_DEFAULT_BASE_URL: + completion_args.update({"api_base": self.base_url}) + elif self.provider == SupportedLiteLLMProvider.Bedrock: + completion_args.pop("api_key", None) + completion_args.pop("api_base", None) + completion_args.update( + { + "aws_access_key_id": self.bedrock_ak, + "aws_secret_access_key": self.bedrock_sk, + "aws_region_name": self.bedrock_region, + } + ) + elif self.provider == SupportedLiteLLMProvider.OpenRouter: + if self.provider_order: - yield total_tokens + def _to_order_list(x): + if x is None: + return [] + if isinstance(x, str): + return [s.strip() for s in x.split(",") if s.strip()] + if isinstance(x, (list, tuple)): + return [str(s).strip() for s in x if str(s).strip()] + return [] - def _calculate_dynamic_ctx(self, history): - """Calculate dynamic context window size""" + extra_body = {} + provider_cfg = {} + provider_order = _to_order_list(self.provider_order) + provider_cfg["order"] = provider_order + provider_cfg["allow_fallbacks"] = False + extra_body["provider"] = provider_cfg + completion_args.update({"extra_body": extra_body}) + elif self.provider == SupportedLiteLLMProvider.GPUStack: + completion_args.update( + { + "api_base": self.base_url, + } + ) - def count_tokens(text): - """Calculate token count for text""" - # Simple calculation: 1 token per ASCII character - # 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.) - total = 0 - for char in text: - if ord(char) < 128: # ASCII characters - total += 1 - else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.) - total += 2 - return total - - # Calculate total tokens for all messages - total_tokens = 0 - for message in history: - content = message.get("content", "") - # Calculate content tokens - content_tokens = count_tokens(content) - # Add role marker token overhead - role_tokens = 4 - total_tokens += content_tokens + role_tokens - - # Apply 1.2x buffer ratio - total_tokens_with_buffer = int(total_tokens * 1.2) - - if total_tokens_with_buffer <= 8192: - ctx_size = 8192 - else: - ctx_multiplier = (total_tokens_with_buffer // 8192) + 1 - ctx_size = ctx_multiplier * 8192 - - return ctx_size + # Ollama deployments commonly sit behind a reverse proxy that enforces + # Bearer auth. Ensure the Authorization header is set when an API key + # is provided, while respecting any user-supplied headers. 
#11350 + extra_headers = deepcopy(completion_args.get("extra_headers") or {}) + if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers: + extra_headers["Authorization"] = f"Bearer {self.api_key}" + if extra_headers: + completion_args["extra_headers"] = extra_headers + return completion_args
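
The OpenRouter branch of _construct_completion_args above normalizes provider_order from either a comma-separated string or a list before pinning it into extra_body with fallbacks disabled. A standalone sketch of that normalization (illustrative only; the provider names are hypothetical):

def to_order_list(x):
    # Accepts None, a comma-separated string, or a list/tuple and returns a cleaned list.
    if x is None:
        return []
    if isinstance(x, str):
        return [s.strip() for s in x.split(",") if s.strip()]
    if isinstance(x, (list, tuple)):
        return [str(s).strip() for s in x if str(s).strip()]
    return []


print(to_order_list("deepinfra, together ,"))    # ['deepinfra', 'together']
print(to_order_list(["fireworks", " novita "]))  # ['fireworks', 'novita']
extra_body = {"provider": {"order": to_order_list("deepinfra,together"), "allow_fallbacks": False}}
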