diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index 89630e4a4..337cb74df 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -23,7 +23,7 @@ from quart import Response, request
from api.apps import current_user, login_required
from api.db.db_models import APIToken
from api.db.services.conversation_service import ConversationService, structure_answer
-from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap
+from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap
from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
from api.db.services.tenant_llm_service import TenantLLMService
@@ -218,10 +218,10 @@ async def completion():
dia.llm_setting = chat_model_config
is_embedded = bool(chat_model_id)
- def stream():
+ async def stream():
nonlocal dia, msg, req, conv
try:
- for ans in chat(dia, msg, True, **req):
+ async for ans in async_chat(dia, msg, True, **req):
ans = structure_answer(conv, ans, message_id, conv.id)
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
if not is_embedded:
@@ -241,7 +241,7 @@ async def completion():
else:
answer = None
- for ans in chat(dia, msg, **req):
+ async for ans in async_chat(dia, msg, **req):
answer = structure_answer(conv, ans, message_id, conv.id)
if not is_embedded:
ConversationService.update_by_id(conv.id, conv.to_dict())
@@ -406,10 +406,10 @@ async def ask_about():
if search_app:
search_config = search_app.get("search_config", {})
- def stream():
+ async def stream():
nonlocal req, uid
try:
- for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
+ async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
diff --git a/api/apps/langfuse_app.py b/api/apps/langfuse_app.py
index 8a05c0d4c..1d7993d36 100644
--- a/api/apps/langfuse_app.py
+++ b/api/apps/langfuse_app.py
@@ -34,8 +34,9 @@ async def set_api_key():
if not all([secret_key, public_key, host]):
return get_error_data_result(message="Missing required fields")
+ current_user_id = current_user.id
langfuse_keys = dict(
- tenant_id=current_user.id,
+ tenant_id=current_user_id,
secret_key=secret_key,
public_key=public_key,
host=host,
@@ -45,23 +46,24 @@ async def set_api_key():
if not langfuse.auth_check():
return get_error_data_result(message="Invalid Langfuse keys")
- langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id)
+ langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id)
with DB.atomic():
try:
if not langfuse_entry:
TenantLangfuseService.save(**langfuse_keys)
else:
- TenantLangfuseService.update_by_tenant(tenant_id=current_user.id, langfuse_keys=langfuse_keys)
+ TenantLangfuseService.update_by_tenant(tenant_id=current_user_id, langfuse_keys=langfuse_keys)
return get_json_result(data=langfuse_keys)
except Exception as e:
- server_error_response(e)
+ return server_error_response(e)
@manager.route("/api_key", methods=["GET"]) # noqa: F821
@login_required
@validate_request()
def get_api_key():
- langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user.id)
+ current_user_id = current_user.id
+ langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user_id)
if not langfuse_entry:
return get_json_result(message="Have not record any Langfuse keys.")
@@ -72,7 +74,7 @@ def get_api_key():
except langfuse.api.core.api_error.ApiError as api_err:
return get_json_result(message=f"Error from Langfuse: {api_err}")
except Exception as e:
- server_error_response(e)
+ return server_error_response(e)
langfuse_entry["project_id"] = langfuse.api.projects.get().dict()["data"][0]["id"]
langfuse_entry["project_name"] = langfuse.api.projects.get().dict()["data"][0]["name"]
@@ -84,7 +86,8 @@ def get_api_key():
@login_required
@validate_request()
def delete_api_key():
- langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id)
+ current_user_id = current_user.id
+ langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id)
if not langfuse_entry:
return get_json_result(message="Have not record any Langfuse keys.")
@@ -93,4 +96,4 @@ def delete_api_key():
TenantLangfuseService.delete_model(langfuse_entry)
return get_json_result(data=True)
except Exception as e:
- server_error_response(e)
+ return server_error_response(e)
diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index 018fb4bca..d24a4bb44 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -74,7 +74,7 @@ async def set_api_key():
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
try:
- m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
+ m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
if m.find("**ERROR**") >= 0:
raise Exception(m)
chat_passed = True
@@ -217,7 +217,7 @@ async def add_llm():
**extra,
)
try:
- m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
+ m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
if not tc and m.find("**ERROR**:") >= 0:
raise Exception(m)
except Exception as e:
diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py
index e94f14fcc..fe4723984 100644
--- a/api/apps/sdk/session.py
+++ b/api/apps/sdk/session.py
@@ -26,9 +26,10 @@ from api.db.db_models import APIToken
from api.db.services.api_service import API4ConversationService
from api.db.services.canvas_service import UserCanvasService, completion_openai
from api.db.services.canvas_service import completion as agent_completion
-from api.db.services.conversation_service import ConversationService, iframe_completion
-from api.db.services.conversation_service import completion as rag_completion
-from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap, meta_filter
+from api.db.services.conversation_service import ConversationService
+from api.db.services.conversation_service import async_iframe_completion as iframe_completion
+from api.db.services.conversation_service import async_completion as rag_completion
+from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap, meta_filter
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
@@ -141,7 +142,7 @@ async def chat_completion(tenant_id, chat_id):
return resp
else:
answer = None
- for ans in rag_completion(tenant_id, chat_id, **req):
+ async for ans in rag_completion(tenant_id, chat_id, **req):
answer = ans
break
return get_result(data=answer)
@@ -245,7 +246,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
# The value for the usage field on all chunks except for the last one will be null.
# The usage field on the last chunk contains token usage statistics for the entire request.
# The choices field on the last chunk will always be an empty array [].
- def streamed_response_generator(chat_id, dia, msg):
+ async def streamed_response_generator(chat_id, dia, msg):
token_used = 0
answer_cache = ""
reasoning_cache = ""
@@ -274,7 +275,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
}
try:
- for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
+ async for ans in async_chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
last_ans = ans
answer = ans["answer"]
@@ -342,7 +343,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
return resp
else:
answer = None
- for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
+ async for ans in async_chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
# focus answer content only
answer = ans
break
@@ -733,10 +734,10 @@ async def ask_about(tenant_id):
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
uid = tenant_id
- def stream():
+ async def stream():
nonlocal req, uid
try:
- for ans in ask(req["question"], req["kb_ids"], uid):
+ async for ans in async_ask(req["question"], req["kb_ids"], uid):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps(
@@ -827,7 +828,7 @@ async def chatbot_completions(dialog_id):
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
- for answer in iframe_completion(dialog_id, **req):
+ async for answer in iframe_completion(dialog_id, **req):
return get_result(data=answer)
@@ -918,10 +919,10 @@ async def ask_about_embedded():
if search_app := SearchService.get_detail(search_id):
search_config = search_app.get("search_config", {})
- def stream():
+ async def stream():
nonlocal req, uid
try:
- for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
+ async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps(
diff --git a/api/db/services/conversation_service.py b/api/db/services/conversation_service.py
index 60f8e55b1..aaec72bf5 100644
--- a/api/db/services/conversation_service.py
+++ b/api/db/services/conversation_service.py
@@ -19,7 +19,7 @@ from common.constants import StatusEnum
from api.db.db_models import Conversation, DB
from api.db.services.api_service import API4ConversationService
from api.db.services.common_service import CommonService
-from api.db.services.dialog_service import DialogService, chat
+from api.db.services.dialog_service import DialogService, async_chat
from common.misc_utils import get_uuid
import json
@@ -89,8 +89,7 @@ def structure_answer(conv, ans, message_id, session_id):
conv.reference[-1] = reference
return ans
-
-def completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
+async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
assert name, "`name` can not be empty."
dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
assert dia, "You do not own the chat."
@@ -112,7 +111,7 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
"reference": {},
"audio_binary": None,
"id": None,
- "session_id": session_id
+ "session_id": session_id
}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
@@ -148,7 +147,7 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
if stream:
try:
- for ans in chat(dia, msg, True, **kwargs):
+ async for ans in async_chat(dia, msg, True, **kwargs):
ans = structure_answer(conv, ans, message_id, session_id)
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
ConversationService.update_by_id(conv.id, conv.to_dict())
@@ -160,14 +159,13 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
else:
answer = None
- for ans in chat(dia, msg, False, **kwargs):
+ async for ans in async_chat(dia, msg, False, **kwargs):
answer = structure_answer(conv, ans, message_id, session_id)
ConversationService.update_by_id(conv.id, conv.to_dict())
break
yield answer
-
-def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
+async def async_iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
e, dia = DialogService.get_by_id(dialog_id)
assert e, "Dialog not found"
if not session_id:
@@ -222,7 +220,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg
if stream:
try:
- for ans in chat(dia, msg, True, **kwargs):
+ async for ans in async_chat(dia, msg, True, **kwargs):
ans = structure_answer(conv, ans, message_id, session_id)
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
ensure_ascii=False) + "\n\n"
@@ -235,7 +233,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg
else:
answer = None
- for ans in chat(dia, msg, False, **kwargs):
+ async for ans in async_chat(dia, msg, False, **kwargs):
answer = structure_answer(conv, ans, message_id, session_id)
API4ConversationService.append_message(conv.id, conv.to_dict())
break
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index 4afdd1f3c..43e345cd2 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -178,7 +178,8 @@ class DialogService(CommonService):
offset += limit
return res
-def chat_solo(dialog, messages, stream=True):
+
+async def async_chat_solo(dialog, messages, stream=True):
attachments = ""
if "files" in messages[-1]:
attachments = "\n\n".join(FileService.get_files(messages[-1]["files"]))
@@ -197,7 +198,8 @@ def chat_solo(dialog, messages, stream=True):
if stream:
last_ans = ""
delta_ans = ""
- for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
+ answer = ""
+ async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
answer = ans
delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 16:
@@ -208,7 +210,7 @@ def chat_solo(dialog, messages, stream=True):
if delta_ans:
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
else:
- answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
+ answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
user_content = msg[-1].get("content", "[content not available]")
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
@@ -347,13 +349,12 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
return []
return list(doc_ids)
-
-def chat(dialog, messages, stream=True, **kwargs):
+async def async_chat(dialog, messages, stream=True, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
- for ans in chat_solo(dialog, messages, stream):
+ async for ans in async_chat_solo(dialog, messages, stream):
yield ans
- return None
+ return
chat_start_ts = timer()
@@ -400,7 +401,7 @@ def chat(dialog, messages, stream=True, **kwargs):
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
if ans:
yield ans
- return None
+ return
for p in prompt_config["parameters"]:
if p["key"] == "knowledge":
@@ -508,7 +509,8 @@ def chat(dialog, messages, stream=True, **kwargs):
empty_res = prompt_config["empty_response"]
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
"audio_binary": tts(tts_mdl, empty_res)}
- return {"answer": prompt_config["empty_response"], "reference": kbinfos}
+ yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
+ return
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
gen_conf = dialog.llm_setting
@@ -612,7 +614,7 @@ def chat(dialog, messages, stream=True, **kwargs):
if stream:
last_ans = ""
answer = ""
- for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
+ async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
if thought:
                 ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
answer = ans
@@ -626,19 +628,19 @@ def chat(dialog, messages, stream=True, **kwargs):
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
yield decorate_answer(thought + answer)
else:
- answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
+ answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
user_content = msg[-1].get("content", "[content not available]")
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
res = decorate_answer(answer)
res["audio_binary"] = tts(tts_mdl, answer)
yield res
- return None
+ return
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
sys_prompt = """
-You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
+You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
Ensure that:
1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it.
2. Write only the SQL, no explanations or additional text.
@@ -805,8 +807,7 @@ def tts(tts_mdl, text):
return None
return binascii.hexlify(bin).decode("utf-8")
-
-def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
+async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
doc_ids = search_config.get("doc_ids", [])
rerank_mdl = None
kb_ids = search_config.get("kb_ids", kb_ids)
@@ -880,7 +881,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
return {"answer": answer, "reference": refs}
answer = ""
- for ans in chat_mdl.chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
+ async for ans in chat_mdl.async_chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
answer = ans
yield {"answer": answer, "reference": {}}
yield decorate_answer(answer)
diff --git a/api/db/services/evaluation_service.py b/api/db/services/evaluation_service.py
index 81b4c44fe..c5a24176d 100644
--- a/api/db/services/evaluation_service.py
+++ b/api/db/services/evaluation_service.py
@@ -25,14 +25,17 @@ Provides functionality for evaluating RAG system performance including:
- Configuration recommendations
"""
+import asyncio
import logging
+import queue
+import threading
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from timeit import default_timer as timer
from api.db.db_models import EvaluationDataset, EvaluationCase, EvaluationRun, EvaluationResult
from api.db.services.common_service import CommonService
-from api.db.services.dialog_service import DialogService, chat
+from api.db.services.dialog_service import DialogService
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp
from common.constants import StatusEnum
@@ -40,24 +43,24 @@ from common.constants import StatusEnum
class EvaluationService(CommonService):
"""Service for managing RAG evaluations"""
-
+
model = EvaluationDataset
-
+
# ==================== Dataset Management ====================
-
+
@classmethod
- def create_dataset(cls, name: str, description: str, kb_ids: List[str],
+ def create_dataset(cls, name: str, description: str, kb_ids: List[str],
tenant_id: str, user_id: str) -> Tuple[bool, str]:
"""
Create a new evaluation dataset.
-
+
Args:
name: Dataset name
description: Dataset description
kb_ids: List of knowledge base IDs to evaluate against
tenant_id: Tenant ID
user_id: User ID who creates the dataset
-
+
Returns:
(success, dataset_id or error_message)
"""
@@ -74,15 +77,15 @@ class EvaluationService(CommonService):
"update_time": current_timestamp(),
"status": StatusEnum.VALID.value
}
-
+
if not EvaluationDataset.create(**dataset):
return False, "Failed to create dataset"
-
+
return True, dataset_id
except Exception as e:
logging.error(f"Error creating evaluation dataset: {e}")
return False, str(e)
-
+
@classmethod
def get_dataset(cls, dataset_id: str) -> Optional[Dict[str, Any]]:
"""Get dataset by ID"""
@@ -94,9 +97,9 @@ class EvaluationService(CommonService):
except Exception as e:
logging.error(f"Error getting dataset {dataset_id}: {e}")
return None
-
+
@classmethod
- def list_datasets(cls, tenant_id: str, user_id: str,
+ def list_datasets(cls, tenant_id: str, user_id: str,
page: int = 1, page_size: int = 20) -> Dict[str, Any]:
"""List datasets for a tenant"""
try:
@@ -104,10 +107,10 @@ class EvaluationService(CommonService):
(EvaluationDataset.tenant_id == tenant_id) &
(EvaluationDataset.status == StatusEnum.VALID.value)
).order_by(EvaluationDataset.create_time.desc())
-
+
total = query.count()
datasets = query.paginate(page, page_size)
-
+
return {
"total": total,
"datasets": [d.to_dict() for d in datasets]
@@ -115,7 +118,7 @@ class EvaluationService(CommonService):
except Exception as e:
logging.error(f"Error listing datasets: {e}")
return {"total": 0, "datasets": []}
-
+
@classmethod
def update_dataset(cls, dataset_id: str, **kwargs) -> bool:
"""Update dataset"""
@@ -127,7 +130,7 @@ class EvaluationService(CommonService):
except Exception as e:
logging.error(f"Error updating dataset {dataset_id}: {e}")
return False
-
+
@classmethod
def delete_dataset(cls, dataset_id: str) -> bool:
"""Soft delete dataset"""
@@ -139,18 +142,18 @@ class EvaluationService(CommonService):
except Exception as e:
logging.error(f"Error deleting dataset {dataset_id}: {e}")
return False
-
+
# ==================== Test Case Management ====================
-
+
@classmethod
- def add_test_case(cls, dataset_id: str, question: str,
+ def add_test_case(cls, dataset_id: str, question: str,
reference_answer: Optional[str] = None,
relevant_doc_ids: Optional[List[str]] = None,
relevant_chunk_ids: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None) -> Tuple[bool, str]:
"""
Add a test case to a dataset.
-
+
Args:
dataset_id: Dataset ID
question: Test question
@@ -158,7 +161,7 @@ class EvaluationService(CommonService):
relevant_doc_ids: Optional list of relevant document IDs
relevant_chunk_ids: Optional list of relevant chunk IDs
metadata: Optional additional metadata
-
+
Returns:
(success, case_id or error_message)
"""
@@ -174,15 +177,15 @@ class EvaluationService(CommonService):
"metadata": metadata,
"create_time": current_timestamp()
}
-
+
if not EvaluationCase.create(**case):
return False, "Failed to create test case"
-
+
return True, case_id
except Exception as e:
logging.error(f"Error adding test case: {e}")
return False, str(e)
-
+
@classmethod
def get_test_cases(cls, dataset_id: str) -> List[Dict[str, Any]]:
"""Get all test cases for a dataset"""
@@ -190,12 +193,12 @@ class EvaluationService(CommonService):
cases = EvaluationCase.select().where(
EvaluationCase.dataset_id == dataset_id
).order_by(EvaluationCase.create_time)
-
+
return [c.to_dict() for c in cases]
except Exception as e:
logging.error(f"Error getting test cases for dataset {dataset_id}: {e}")
return []
-
+
@classmethod
def delete_test_case(cls, case_id: str) -> bool:
"""Delete a test case"""
@@ -206,22 +209,22 @@ class EvaluationService(CommonService):
except Exception as e:
logging.error(f"Error deleting test case {case_id}: {e}")
return False
-
+
@classmethod
def import_test_cases(cls, dataset_id: str, cases: List[Dict[str, Any]]) -> Tuple[int, int]:
"""
Bulk import test cases from a list.
-
+
Args:
dataset_id: Dataset ID
cases: List of test case dictionaries
-
+
Returns:
(success_count, failure_count)
"""
success_count = 0
failure_count = 0
-
+
for case_data in cases:
success, _ = cls.add_test_case(
dataset_id=dataset_id,
@@ -231,28 +234,28 @@ class EvaluationService(CommonService):
relevant_chunk_ids=case_data.get("relevant_chunk_ids"),
metadata=case_data.get("metadata")
)
-
+
if success:
success_count += 1
else:
failure_count += 1
-
+
return success_count, failure_count
-
+
# ==================== Evaluation Execution ====================
-
+
@classmethod
- def start_evaluation(cls, dataset_id: str, dialog_id: str,
+ def start_evaluation(cls, dataset_id: str, dialog_id: str,
user_id: str, name: Optional[str] = None) -> Tuple[bool, str]:
"""
Start an evaluation run.
-
+
Args:
dataset_id: Dataset ID
dialog_id: Dialog configuration to evaluate
user_id: User ID who starts the run
name: Optional run name
-
+
Returns:
(success, run_id or error_message)
"""
@@ -261,12 +264,12 @@ class EvaluationService(CommonService):
success, dialog = DialogService.get_by_id(dialog_id)
if not success:
return False, "Dialog not found"
-
+
# Create evaluation run
run_id = get_uuid()
if not name:
name = f"Evaluation Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
-
+
run = {
"id": run_id,
"dataset_id": dataset_id,
@@ -279,92 +282,128 @@ class EvaluationService(CommonService):
"create_time": current_timestamp(),
"complete_time": None
}
-
+
if not EvaluationRun.create(**run):
return False, "Failed to create evaluation run"
-
+
# Execute evaluation asynchronously (in production, use task queue)
# For now, we'll execute synchronously
cls._execute_evaluation(run_id, dataset_id, dialog)
-
+
return True, run_id
except Exception as e:
logging.error(f"Error starting evaluation: {e}")
return False, str(e)
-
+
@classmethod
def _execute_evaluation(cls, run_id: str, dataset_id: str, dialog: Any):
"""
Execute evaluation for all test cases.
-
+
This method runs the RAG pipeline for each test case and computes metrics.
"""
try:
# Get all test cases
test_cases = cls.get_test_cases(dataset_id)
-
+
if not test_cases:
EvaluationRun.update(
status="FAILED",
complete_time=current_timestamp()
).where(EvaluationRun.id == run_id).execute()
return
-
+
# Execute each test case
results = []
for case in test_cases:
result = cls._evaluate_single_case(run_id, case, dialog)
if result:
results.append(result)
-
+
# Compute summary metrics
metrics_summary = cls._compute_summary_metrics(results)
-
+
# Update run status
EvaluationRun.update(
status="COMPLETED",
metrics_summary=metrics_summary,
complete_time=current_timestamp()
).where(EvaluationRun.id == run_id).execute()
-
+
except Exception as e:
logging.error(f"Error executing evaluation {run_id}: {e}")
EvaluationRun.update(
status="FAILED",
complete_time=current_timestamp()
).where(EvaluationRun.id == run_id).execute()
-
+
@classmethod
- def _evaluate_single_case(cls, run_id: str, case: Dict[str, Any],
+ def _evaluate_single_case(cls, run_id: str, case: Dict[str, Any],
dialog: Any) -> Optional[Dict[str, Any]]:
"""
Evaluate a single test case.
-
+
Args:
run_id: Evaluation run ID
case: Test case dictionary
dialog: Dialog configuration
-
+
Returns:
Result dictionary or None if failed
"""
try:
# Prepare messages
messages = [{"role": "user", "content": case["question"]}]
-
+
# Execute RAG pipeline
start_time = timer()
answer = ""
retrieved_chunks = []
-
+
+
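+        # Helper: drain the async generator on a worker thread with its own event loop,
+        # relaying items through a queue so this synchronous method can iterate them.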
+ def _sync_from_async_gen(async_gen):
+ result_queue: queue.Queue = queue.Queue()
+
+ def runner():
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+
+ async def consume():
+ try:
+ async for item in async_gen:
+ result_queue.put(item)
+ except Exception as e:
+ result_queue.put(e)
+ finally:
+ result_queue.put(StopIteration)
+
+ loop.run_until_complete(consume())
+ loop.close()
+
+ threading.Thread(target=runner, daemon=True).start()
+
+ while True:
+ item = result_queue.get()
+ if item is StopIteration:
+ break
+ if isinstance(item, Exception):
+ raise item
+ yield item
+
+
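+        # Local shim that keeps the synchronous chat() call below working now that
+        # dialog_service.chat has been replaced by async_chat.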
+ def chat(dialog, messages, stream=True, **kwargs):
+ from api.db.services.dialog_service import async_chat
+
+ return _sync_from_async_gen(async_chat(dialog, messages, stream=stream, **kwargs))
+
for ans in chat(dialog, messages, stream=False):
if isinstance(ans, dict):
answer = ans.get("answer", "")
retrieved_chunks = ans.get("reference", {}).get("chunks", [])
break
-
+
execution_time = timer() - start_time
-
+
# Compute metrics
metrics = cls._compute_metrics(
question=case["question"],
@@ -374,7 +413,7 @@ class EvaluationService(CommonService):
relevant_chunk_ids=case.get("relevant_chunk_ids"),
dialog=dialog
)
-
+
# Save result
result_id = get_uuid()
result = {
@@ -388,14 +427,14 @@ class EvaluationService(CommonService):
"token_usage": None, # TODO: Track token usage
"create_time": current_timestamp()
}
-
+
EvaluationResult.create(**result)
-
+
return result
except Exception as e:
logging.error(f"Error evaluating case {case.get('id')}: {e}")
return None
-
+
@classmethod
def _compute_metrics(cls, question: str, generated_answer: str,
reference_answer: Optional[str],
@@ -404,69 +443,69 @@ class EvaluationService(CommonService):
dialog: Any) -> Dict[str, float]:
"""
Compute evaluation metrics for a single test case.
-
+
Returns:
Dictionary of metric names to values
"""
metrics = {}
-
+
# Retrieval metrics (if ground truth chunks provided)
if relevant_chunk_ids:
retrieved_ids = [c.get("chunk_id") for c in retrieved_chunks]
metrics.update(cls._compute_retrieval_metrics(retrieved_ids, relevant_chunk_ids))
-
+
# Generation metrics
if generated_answer:
# Basic metrics
metrics["answer_length"] = len(generated_answer)
metrics["has_answer"] = 1.0 if generated_answer.strip() else 0.0
-
+
# TODO: Implement advanced metrics using LLM-as-judge
# - Faithfulness (hallucination detection)
# - Answer relevance
# - Context relevance
# - Semantic similarity (if reference answer provided)
-
+
return metrics
-
+
@classmethod
- def _compute_retrieval_metrics(cls, retrieved_ids: List[str],
+ def _compute_retrieval_metrics(cls, retrieved_ids: List[str],
relevant_ids: List[str]) -> Dict[str, float]:
"""
Compute retrieval metrics.
-
+
Args:
retrieved_ids: List of retrieved chunk IDs
relevant_ids: List of relevant chunk IDs (ground truth)
-
+
Returns:
Dictionary of retrieval metrics
"""
if not relevant_ids:
return {}
-
+
retrieved_set = set(retrieved_ids)
relevant_set = set(relevant_ids)
-
+
# Precision: proportion of retrieved that are relevant
precision = len(retrieved_set & relevant_set) / len(retrieved_set) if retrieved_set else 0.0
-
+
# Recall: proportion of relevant that were retrieved
recall = len(retrieved_set & relevant_set) / len(relevant_set) if relevant_set else 0.0
-
+
# F1 score
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
-
+
# Hit rate: whether any relevant chunk was retrieved
hit_rate = 1.0 if (retrieved_set & relevant_set) else 0.0
-
+
# MRR (Mean Reciprocal Rank): position of first relevant chunk
mrr = 0.0
for i, chunk_id in enumerate(retrieved_ids, 1):
if chunk_id in relevant_set:
mrr = 1.0 / i
break
-
+
return {
"precision": precision,
"recall": recall,
@@ -474,45 +513,45 @@ class EvaluationService(CommonService):
"hit_rate": hit_rate,
"mrr": mrr
}
-
+
@classmethod
def _compute_summary_metrics(cls, results: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Compute summary metrics across all test cases.
-
+
Args:
results: List of result dictionaries
-
+
Returns:
Summary metrics dictionary
"""
if not results:
return {}
-
+
# Aggregate metrics
metric_sums = {}
metric_counts = {}
-
+
for result in results:
metrics = result.get("metrics", {})
for key, value in metrics.items():
if isinstance(value, (int, float)):
metric_sums[key] = metric_sums.get(key, 0) + value
metric_counts[key] = metric_counts.get(key, 0) + 1
-
+
# Compute averages
summary = {
"total_cases": len(results),
"avg_execution_time": sum(r.get("execution_time", 0) for r in results) / len(results)
}
-
+
for key in metric_sums:
summary[f"avg_{key}"] = metric_sums[key] / metric_counts[key]
-
+
return summary
-
+
# ==================== Results & Analysis ====================
-
+
@classmethod
def get_run_results(cls, run_id: str) -> Dict[str, Any]:
"""Get results for an evaluation run"""
@@ -520,11 +559,11 @@ class EvaluationService(CommonService):
run = EvaluationRun.get_by_id(run_id)
if not run:
return {}
-
+
results = EvaluationResult.select().where(
EvaluationResult.run_id == run_id
).order_by(EvaluationResult.create_time)
-
+
return {
"run": run.to_dict(),
"results": [r.to_dict() for r in results]
@@ -532,15 +571,15 @@ class EvaluationService(CommonService):
except Exception as e:
logging.error(f"Error getting run results {run_id}: {e}")
return {}
-
+
@classmethod
def get_recommendations(cls, run_id: str) -> List[Dict[str, Any]]:
"""
Analyze evaluation results and provide configuration recommendations.
-
+
Args:
run_id: Evaluation run ID
-
+
Returns:
List of recommendation dictionaries
"""
@@ -548,10 +587,10 @@ class EvaluationService(CommonService):
run = EvaluationRun.get_by_id(run_id)
if not run or not run.metrics_summary:
return []
-
+
metrics = run.metrics_summary
recommendations = []
-
+
# Low precision: retrieving irrelevant chunks
if metrics.get("avg_precision", 1.0) < 0.7:
recommendations.append({
@@ -564,7 +603,7 @@ class EvaluationService(CommonService):
"Reduce top_k to return fewer chunks"
]
})
-
+
# Low recall: missing relevant chunks
if metrics.get("avg_recall", 1.0) < 0.7:
recommendations.append({
@@ -578,7 +617,7 @@ class EvaluationService(CommonService):
"Check chunk size - may be too large or too small"
]
})
-
+
# Slow response time
if metrics.get("avg_execution_time", 0) > 5.0:
recommendations.append({
@@ -591,7 +630,7 @@ class EvaluationService(CommonService):
"Consider caching frequently asked questions"
]
})
-
+
return recommendations
except Exception as e:
logging.error(f"Error generating recommendations for run {run_id}: {e}")
diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index 6a63713ec..86356a7a7 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -16,15 +16,17 @@
import asyncio
import inspect
import logging
+import queue
import re
import threading
-from common.token_utils import num_tokens_from_string
from functools import partial
from typing import Generator
-from common.constants import LLMType
+
from api.db.db_models import LLM
from api.db.services.common_service import CommonService
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
+from common.constants import LLMType
+from common.token_utils import num_tokens_from_string
class LLMService(CommonService):
@@ -33,6 +35,7 @@ class LLMService(CommonService):
def get_init_tenant_llm(user_id):
from common import settings
+
tenant_llm = []
model_configs = {
@@ -193,7 +196,7 @@ class LLMBundle(LLM4Tenant):
generation = self.langfuse.start_generation(
trace_context=self.trace_context,
name="stream_transcription",
- metadata={"model": self.llm_name}
+ metadata={"model": self.llm_name},
)
final_text = ""
used_tokens = 0
@@ -217,32 +220,34 @@ class LLMBundle(LLM4Tenant):
if self.langfuse:
generation.update(
output={"output": final_text},
- usage_details={"total_tokens": used_tokens}
+ usage_details={"total_tokens": used_tokens},
)
generation.end()
return
if self.langfuse:
- generation = self.langfuse.start_generation(trace_context=self.trace_context, name="stream_transcription", metadata={"model": self.llm_name})
- full_text, used_tokens = mdl.transcription(audio)
- if not TenantLLMService.increase_usage(
- self.tenant_id, self.llm_type, used_tokens
- ):
- logging.error(
- f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}"
+ generation = self.langfuse.start_generation(
+ trace_context=self.trace_context,
+ name="stream_transcription",
+ metadata={"model": self.llm_name},
)
+
+ full_text, used_tokens = mdl.transcription(audio)
+ if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+ logging.error(f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}")
+
if self.langfuse:
generation.update(
output={"output": full_text},
- usage_details={"total_tokens": used_tokens}
+ usage_details={"total_tokens": used_tokens},
)
generation.end()
yield {
"event": "final",
"text": full_text,
- "streaming": False
+ "streaming": False,
}
def tts(self, text: str) -> Generator[bytes, None, None]:
@@ -289,61 +294,79 @@ class LLMBundle(LLM4Tenant):
return kwargs
else:
return {k: v for k, v in kwargs.items() if k in allowed_params}
+
+ def _run_coroutine_sync(self, coro):
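+        """Run a coroutine to completion from synchronous code.
+
+        Uses asyncio.run() when no event loop is running; otherwise executes the
+        coroutine on a short-lived thread with its own event loop.
+        """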
+ try:
+ asyncio.get_running_loop()
+ except RuntimeError:
+ return asyncio.run(coro)
+
+ result_queue: queue.Queue = queue.Queue()
+
+ def runner():
+ try:
+ result_queue.put((True, asyncio.run(coro)))
+ except Exception as e:
+ result_queue.put((False, e))
+
+ thread = threading.Thread(target=runner, daemon=True)
+ thread.start()
+ thread.join()
+
+ success, value = result_queue.get_nowait()
+ if success:
+ return value
+ raise value
+
def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
- if self.langfuse:
- generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
+ return self._run_coroutine_sync(self.async_chat(system, history, gen_conf, **kwargs))
- chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
- if self.is_tools and self.mdl.is_tools:
- chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
+ def _sync_from_async_stream(self, async_gen_fn, *args, **kwargs):
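+        """Consume an async generator from synchronous code.
+
+        The generator runs on a daemon thread with its own event loop; items (or a
+        raised exception) are relayed back through a queue and yielded here.
+        """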
+ result_queue: queue.Queue = queue.Queue()
- use_kwargs = self._clean_param(chat_partial, **kwargs)
- txt, used_tokens = chat_partial(**use_kwargs)
- txt = self._remove_reasoning_content(txt)
+ def runner():
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
- if not self.verbose_tool_use:
-            txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
+ async def consume():
+ try:
+ async for item in async_gen_fn(*args, **kwargs):
+ result_queue.put(item)
+ except Exception as e:
+ result_queue.put(e)
+ finally:
+ result_queue.put(StopIteration)
- if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
- logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
+ loop.run_until_complete(consume())
+ loop.close()
- if self.langfuse:
- generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
- generation.end()
+ threading.Thread(target=runner, daemon=True).start()
- return txt
+ while True:
+ item = result_queue.get()
+ if item is StopIteration:
+ break
+ if isinstance(item, Exception):
+ raise item
+ yield item
def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
- if self.langfuse:
- generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
-
ans = ""
- chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf)
- total_tokens = 0
- if self.is_tools and self.mdl.is_tools:
- chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf)
- use_kwargs = self._clean_param(chat_partial, **kwargs)
- for txt in chat_partial(**use_kwargs):
+ for txt in self._sync_from_async_stream(self.async_chat_streamly, system, history, gen_conf, **kwargs):
if isinstance(txt, int):
- total_tokens = txt
- if self.langfuse:
- generation.update(output={"output": ans})
- generation.end()
break
             if txt.endswith("</think>"):
-                ans = ans[: -len("</think>")]
+                ans = txt[: -len("</think>")]
+ continue
if not self.verbose_tool_use:
txt = re.sub(r".*?", "", txt, flags=re.DOTALL)
- ans += txt
+            # concatenation has already been done in async_chat_streamly
+ ans = txt
yield ans
- if total_tokens > 0:
- if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
- logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
-
def _bridge_sync_stream(self, gen):
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
@@ -352,7 +375,7 @@ class LLMBundle(LLM4Tenant):
try:
for item in gen:
loop.call_soon_threadsafe(queue.put_nowait, item)
- except Exception as e: # pragma: no cover
+ except Exception as e:
loop.call_soon_threadsafe(queue.put_nowait, e)
finally:
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
@@ -361,18 +384,27 @@ class LLMBundle(LLM4Tenant):
return queue
async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
- chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
- if self.is_tools and self.mdl.is_tools and hasattr(self.mdl, "chat_with_tools"):
- chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
+ if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_with_tools"):
+ base_fn = self.mdl.async_chat_with_tools
+ elif hasattr(self.mdl, "async_chat"):
+ base_fn = self.mdl.async_chat
+ else:
+ raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
+ generation = None
+ if self.langfuse:
+ generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
+
+ chat_partial = partial(base_fn, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
- if hasattr(self.mdl, "async_chat_with_tools") and self.is_tools and self.mdl.is_tools:
- txt, used_tokens = await self.mdl.async_chat_with_tools(system, history, gen_conf, **use_kwargs)
- elif hasattr(self.mdl, "async_chat"):
- txt, used_tokens = await self.mdl.async_chat(system, history, gen_conf, **use_kwargs)
- else:
- txt, used_tokens = await asyncio.to_thread(chat_partial, **use_kwargs)
+ try:
+ txt, used_tokens = await chat_partial(**use_kwargs)
+ except Exception as e:
+ if generation:
+ generation.update(output={"error": str(e)})
+ generation.end()
+ raise
txt = self._remove_reasoning_content(txt)
if not self.verbose_tool_use:
@@ -381,49 +413,51 @@ class LLMBundle(LLM4Tenant):
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
+ if generation:
+ generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
+ generation.end()
+
return txt
async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
total_tokens = 0
ans = ""
- if self.is_tools and self.mdl.is_tools:
+ if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"):
stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
- else:
+ elif hasattr(self.mdl, "async_chat_streamly"):
stream_fn = getattr(self.mdl, "async_chat_streamly", None)
+ else:
+            raise RuntimeError(f"Model {self.mdl} does not implement async_chat_streamly or async_chat_streamly_with_tools")
+
+ generation = None
+ if self.langfuse:
+ generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
if stream_fn:
chat_partial = partial(stream_fn, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
- async for txt in chat_partial(**use_kwargs):
- if isinstance(txt, int):
- total_tokens = txt
- break
+ try:
+ async for txt in chat_partial(**use_kwargs):
+ if isinstance(txt, int):
+ total_tokens = txt
+ break
-                if txt.endswith("</think>"):
-                    ans = ans[: -len("</think>")]
+                    if txt.endswith("</think>"):
+                        ans = ans[: -len("</think>")]
-                if not self.verbose_tool_use:
-                    txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
+                    if not self.verbose_tool_use:
+                        txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
- ans += txt
- yield ans
+ ans += txt
+ yield ans
+ except Exception as e:
+ if generation:
+ generation.update(output={"error": str(e)})
+ generation.end()
+ raise
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
+ if generation:
+ generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
+ generation.end()
return
-
- chat_partial = partial(self.mdl.chat_streamly_with_tools if (self.is_tools and self.mdl.is_tools) else self.mdl.chat_streamly, system, history, gen_conf)
- use_kwargs = self._clean_param(chat_partial, **kwargs)
- queue = self._bridge_sync_stream(chat_partial(**use_kwargs))
- while True:
- item = await queue.get()
- if item is StopAsyncIteration:
- break
- if isinstance(item, Exception):
- raise item
- if isinstance(item, int):
- total_tokens = item
- break
- yield item
-
- if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
- logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index 3ff5311fc..67bf0bb09 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -52,6 +52,8 @@ class SupportedLiteLLMProvider(StrEnum):
JiekouAI = "Jiekou.AI"
ZHIPU_AI = "ZHIPU-AI"
MiniMax = "MiniMax"
+ DeerAPI = "DeerAPI"
+ GPUStack = "GPUStack"
FACTORY_DEFAULT_BASE_URL = {
@@ -75,6 +77,7 @@ FACTORY_DEFAULT_BASE_URL = {
SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai",
SupportedLiteLLMProvider.ZHIPU_AI: "https://open.bigmodel.cn/api/paas/v4",
SupportedLiteLLMProvider.MiniMax: "https://api.minimaxi.com/v1",
+ SupportedLiteLLMProvider.DeerAPI: "https://api.deerapi.com/v1",
}
@@ -108,6 +111,8 @@ LITELLM_PROVIDER_PREFIX = {
SupportedLiteLLMProvider.JiekouAI: "openai/",
SupportedLiteLLMProvider.ZHIPU_AI: "openai/",
SupportedLiteLLMProvider.MiniMax: "openai/",
+ SupportedLiteLLMProvider.DeerAPI: "openai/",
+ SupportedLiteLLMProvider.GPUStack: "openai/",
}
ChatModel = globals().get("ChatModel", {})
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index e69ff1868..9f5457224 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -19,7 +19,6 @@ import logging
import os
import random
import re
-import threading
import time
from abc import ABC
from copy import deepcopy
@@ -78,11 +77,9 @@ class Base(ABC):
self.toolcall_sessions = {}
def _get_delay(self):
- """Calculate retry delay time"""
return self.base_delay * random.uniform(10, 150)
def _classify_error(self, error):
- """Classify error based on error message content"""
error_str = str(error).lower()
keywords_mapping = [
@@ -139,89 +136,7 @@ class Base(ABC):
return gen_conf
- def _bridge_sync_stream(self, gen):
- """Run a sync generator in a thread and yield asynchronously."""
- loop = asyncio.get_running_loop()
- queue: asyncio.Queue = asyncio.Queue()
-
- def worker():
- try:
- for item in gen:
- loop.call_soon_threadsafe(queue.put_nowait, item)
- except Exception as exc: # pragma: no cover - defensive
- loop.call_soon_threadsafe(queue.put_nowait, exc)
- finally:
- loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
-
- threading.Thread(target=worker, daemon=True).start()
- return queue
-
- def _chat(self, history, gen_conf, **kwargs):
- logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
- if self.model_name.lower().find("qwq") >= 0:
- logging.info(f"[INFO] {self.model_name} detected as reasoning model, using _chat_streamly")
-
- final_ans = ""
- tol_token = 0
- for delta, tol in self._chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
-                if delta.startswith("<think>") or delta.endswith("</think>"):
- continue
- final_ans += delta
- tol_token = tol
-
- if len(final_ans.strip()) == 0:
- final_ans = "**ERROR**: Empty response from reasoning model"
-
- return final_ans.strip(), tol_token
-
- if self.model_name.lower().find("qwen3") >= 0:
- kwargs["extra_body"] = {"enable_thinking": False}
-
- response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
-
- if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
- return "", 0
- ans = response.choices[0].message.content.strip()
- if response.choices[0].finish_reason == "length":
- ans = self._length_stop(ans)
- return ans, total_token_count_from_response(response)
-
- def _chat_streamly(self, history, gen_conf, **kwargs):
- logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
- reasoning_start = False
-
- if kwargs.get("stop") or "stop" in gen_conf:
- response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
- else:
- response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
-
- for resp in response:
- if not resp.choices:
- continue
- if not resp.choices[0].delta.content:
- resp.choices[0].delta.content = ""
- if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
- ans = ""
- if not reasoning_start:
- reasoning_start = True
-                    ans = "<think>"
-                ans += resp.choices[0].delta.reasoning_content + "</think>"
- else:
- reasoning_start = False
- ans = resp.choices[0].delta.content
-
- tol = total_token_count_from_response(resp)
- if not tol:
- tol = num_tokens_from_string(resp.choices[0].delta.content)
-
- if resp.choices[0].finish_reason == "length":
- if is_chinese(ans):
- ans += LENGTH_NOTIFICATION_CN
- else:
- ans += LENGTH_NOTIFICATION_EN
- yield ans, tol
-
- async def _async_chat_stream(self, history, gen_conf, **kwargs):
+ async def _async_chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
reasoning_start = False
@@ -265,13 +180,19 @@ class Base(ABC):
gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
- try:
- async for delta_ans, tol in self._async_chat_stream(history, gen_conf, **kwargs):
- ans = delta_ans
- total_tokens += tol
- yield delta_ans
- except openai.APIError as e:
- yield ans + "\n**ERROR**: " + str(e)
+
+ for attempt in range(self.max_retries + 1):
+ try:
+ async for delta_ans, tol in self._async_chat_streamly(history, gen_conf, **kwargs):
+ ans = delta_ans
+ total_tokens += tol
+ yield ans
+ except Exception as e:
+ e = await self._exceptions_async(e, attempt)
+ if e:
+ yield e
+ yield total_tokens
+ return
yield total_tokens
@@ -307,7 +228,7 @@ class Base(ABC):
logging.error(f"sync base giving up: {msg}")
return msg
- async def _exceptions_async(self, e, attempt) -> str | None:
+ async def _exceptions_async(self, e, attempt):
logging.exception("OpenAI async completion")
error_code = self._classify_error(e)
if attempt == self.max_retries:
@@ -357,61 +278,6 @@ class Base(ABC):
self.toolcall_session = toolcall_session
self.tools = tools
- def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
- gen_conf = self._clean_conf(gen_conf)
- if system and history and history[0].get("role") != "system":
- history.insert(0, {"role": "system", "content": system})
-
- ans = ""
- tk_count = 0
- hist = deepcopy(history)
- # Implement exponential backoff retry strategy
- for attempt in range(self.max_retries + 1):
- history = hist
- try:
- for _ in range(self.max_rounds + 1):
- logging.info(f"{self.tools=}")
- response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
- tk_count += total_token_count_from_response(response)
- if any([not response.choices, not response.choices[0].message]):
- raise Exception(f"500 response structure error. Response: {response}")
-
- if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
- if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
-                        ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"
-
- ans += response.choices[0].message.content
- if response.choices[0].finish_reason == "length":
- ans = self._length_stop(ans)
-
- return ans, tk_count
-
- for tool_call in response.choices[0].message.tool_calls:
- logging.info(f"Response {tool_call=}")
- name = tool_call.function.name
- try:
- args = json_repair.loads(tool_call.function.arguments)
- tool_response = self.toolcall_session.tool_call(name, args)
- history = self._append_history(history, tool_call, tool_response)
- ans += self._verbose_tool_use(name, args, tool_response)
- except Exception as e:
- logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
- history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
- ans += self._verbose_tool_use(name, {}, str(e))
-
- logging.warning(f"Exceed max rounds: {self.max_rounds}")
- history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
- response, token_count = self._chat(history, gen_conf)
- ans += response
- tk_count += token_count
- return ans, tk_count
- except Exception as e:
- e = self._exceptions(e, attempt)
- if e:
- return e, tk_count
-
- assert False, "Shouldn't be here."
-
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system":
@@ -466,140 +332,6 @@ class Base(ABC):
assert False, "Shouldn't be here."
- def chat(self, system, history, gen_conf={}, **kwargs):
- if system and history and history[0].get("role") != "system":
- history.insert(0, {"role": "system", "content": system})
- gen_conf = self._clean_conf(gen_conf)
-
- # Implement exponential backoff retry strategy
- for attempt in range(self.max_retries + 1):
- try:
- return self._chat(history, gen_conf, **kwargs)
- except Exception as e:
- e = self._exceptions(e, attempt)
- if e:
- return e, 0
- assert False, "Shouldn't be here."
-
- def _wrap_toolcall_message(self, stream):
- final_tool_calls = {}
-
- for chunk in stream:
- for tool_call in chunk.choices[0].delta.tool_calls or []:
- index = tool_call.index
-
- if index not in final_tool_calls:
- final_tool_calls[index] = tool_call
-
- final_tool_calls[index].function.arguments += tool_call.function.arguments
-
- return final_tool_calls
-
- def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
- gen_conf = self._clean_conf(gen_conf)
- tools = self.tools
- if system and history and history[0].get("role") != "system":
- history.insert(0, {"role": "system", "content": system})
-
- total_tokens = 0
- hist = deepcopy(history)
- # Implement exponential backoff retry strategy
- for attempt in range(self.max_retries + 1):
- history = hist
- try:
- for _ in range(self.max_rounds + 1):
- reasoning_start = False
- logging.info(f"{tools=}")
- response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
- final_tool_calls = {}
- answer = ""
- for resp in response:
- if resp.choices[0].delta.tool_calls:
- for tool_call in resp.choices[0].delta.tool_calls or []:
- index = tool_call.index
-
- if index not in final_tool_calls:
- if not tool_call.function.arguments:
- tool_call.function.arguments = ""
- final_tool_calls[index] = tool_call
- else:
- final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""
- continue
-
- if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
- raise Exception("500 response structure error.")
-
- if not resp.choices[0].delta.content:
- resp.choices[0].delta.content = ""
-
- if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
- ans = ""
- if not reasoning_start:
- reasoning_start = True
-                            ans = "<think>"
-                        ans += resp.choices[0].delta.reasoning_content + "</think>"
- yield ans
- else:
- reasoning_start = False
- answer += resp.choices[0].delta.content
- yield resp.choices[0].delta.content
-
- tol = total_token_count_from_response(resp)
- if not tol:
- total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
- else:
- total_tokens = tol
-
- finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
- if finish_reason == "length":
- yield self._length_stop("")
-
- if answer:
- yield total_tokens
- return
-
- for tool_call in final_tool_calls.values():
- name = tool_call.function.name
- try:
- args = json_repair.loads(tool_call.function.arguments)
- yield self._verbose_tool_use(name, args, "Begin to call...")
- tool_response = self.toolcall_session.tool_call(name, args)
- history = self._append_history(history, tool_call, tool_response)
- yield self._verbose_tool_use(name, args, tool_response)
- except Exception as e:
- logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
- history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
- yield self._verbose_tool_use(name, {}, str(e))
-
- logging.warning(f"Exceed max rounds: {self.max_rounds}")
- history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
- response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
- for resp in response:
- if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
- raise Exception("500 response structure error.")
- if not resp.choices[0].delta.content:
- resp.choices[0].delta.content = ""
- continue
- tol = total_token_count_from_response(resp)
- if not tol:
- total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
- else:
- total_tokens = tol
- answer += resp.choices[0].delta.content
- yield resp.choices[0].delta.content
-
- yield total_tokens
- return
-
- except Exception as e:
- e = self._exceptions(e, attempt)
- if e:
- yield e
- yield total_tokens
- return
-
- assert False, "Shouldn't be here."
-
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
@@ -715,9 +447,10 @@ class Base(ABC):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwq") >= 0:
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly")
+
final_ans = ""
tol_token = 0
- async for delta, tol in self._async_chat_stream(history, gen_conf, with_reasoning=False, **kwargs):
+ async for delta, tol in self._async_chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
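+ # Drop <think>-wrapped reasoning chunks so only the final answer text is accumulated.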
if delta.startswith("") or delta.endswith(""):
continue
final_ans += delta
@@ -754,57 +487,6 @@ class Base(ABC):
return e, 0
assert False, "Shouldn't be here."
- def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
- if system and history and history[0].get("role") != "system":
- history.insert(0, {"role": "system", "content": system})
- gen_conf = self._clean_conf(gen_conf)
- ans = ""
- total_tokens = 0
- try:
- for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
- yield delta_ans
- total_tokens += tol
- except openai.APIError as e:
- yield ans + "\n**ERROR**: " + str(e)
-
- yield total_tokens
-
- def _calculate_dynamic_ctx(self, history):
- """Calculate dynamic context window size"""
-
- def count_tokens(text):
- """Calculate token count for text"""
- # Simple calculation: 1 token per ASCII character
- # 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
- total = 0
- for char in text:
- if ord(char) < 128: # ASCII characters
- total += 1
- else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
- total += 2
- return total
-
- # Calculate total tokens for all messages
- total_tokens = 0
- for message in history:
- content = message.get("content", "")
- # Calculate content tokens
- content_tokens = count_tokens(content)
- # Add role marker token overhead
- role_tokens = 4
- total_tokens += content_tokens + role_tokens
-
- # Apply 1.2x buffer ratio
- total_tokens_with_buffer = int(total_tokens * 1.2)
-
- if total_tokens_with_buffer <= 8192:
- ctx_size = 8192
- else:
- ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
- ctx_size = ctx_multiplier * 8192
-
- return ctx_size
-
class GptTurbo(Base):
_FACTORY_NAME = "OpenAI"
@@ -1504,16 +1186,6 @@ class GoogleChat(Base):
yield total_tokens
-class GPUStackChat(Base):
- _FACTORY_NAME = "GPUStack"
-
- def __init__(self, key=None, model_name="", base_url="", **kwargs):
- if not base_url:
- raise ValueError("Local llm url cannot be None")
- base_url = urljoin(base_url, "v1")
- super().__init__(key, model_name, base_url, **kwargs)
-
-
class TokenPonyChat(Base):
_FACTORY_NAME = "TokenPony"
@@ -1523,15 +1195,6 @@ class TokenPonyChat(Base):
super().__init__(key, model_name, base_url, **kwargs)
-class DeerAPIChat(Base):
- _FACTORY_NAME = "DeerAPI"
-
- def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1", **kwargs):
- if not base_url:
- base_url = "https://api.deerapi.com/v1"
- super().__init__(key, model_name, base_url, **kwargs)
-
-
class LiteLLMBase(ABC):
_FACTORY_NAME = [
"Tongyi-Qianwen",
@@ -1562,6 +1225,8 @@ class LiteLLMBase(ABC):
"Jiekou.AI",
"ZHIPU-AI",
"MiniMax",
+ "DeerAPI",
+ "GPUStack",
]
def __init__(self, key, model_name, base_url=None, **kwargs):
@@ -1589,11 +1254,9 @@ class LiteLLMBase(ABC):
self.provider_order = json.loads(key).get("provider_order", "")
def _get_delay(self):
- """Calculate retry delay time"""
return self.base_delay * random.uniform(10, 150)
def _classify_error(self, error):
- """Classify error based on error message content"""
error_str = str(error).lower()
keywords_mapping = [
@@ -1619,72 +1282,6 @@ class LiteLLMBase(ABC):
del gen_conf["max_tokens"]
return gen_conf
- def _chat(self, history, gen_conf, **kwargs):
- logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
- if self.model_name.lower().find("qwen3") >= 0:
- kwargs["extra_body"] = {"enable_thinking": False}
-
- completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf)
- response = litellm.completion(
- **completion_args,
- drop_params=True,
- timeout=self.timeout,
- )
- # response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
- if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
- return "", 0
- ans = response.choices[0].message.content.strip()
- if response.choices[0].finish_reason == "length":
- ans = self._length_stop(ans)
-
- return ans, total_token_count_from_response(response)
-
- def _chat_streamly(self, history, gen_conf, **kwargs):
- logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
- gen_conf = self._clean_conf(gen_conf)
- reasoning_start = False
-
- completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
- stop = kwargs.get("stop")
- if stop:
- completion_args["stop"] = stop
- response = litellm.completion(
- **completion_args,
- drop_params=True,
- timeout=self.timeout,
- )
-
- for resp in response:
- if not hasattr(resp, "choices") or not resp.choices:
- continue
-
- delta = resp.choices[0].delta
- if not hasattr(delta, "content") or delta.content is None:
- delta.content = ""
-
- if kwargs.get("with_reasoning", True) and hasattr(delta, "reasoning_content") and delta.reasoning_content:
- ans = ""
- if not reasoning_start:
- reasoning_start = True
- ans = ""
- ans += delta.reasoning_content + ""
- else:
- reasoning_start = False
- ans = delta.content
-
- tol = total_token_count_from_response(resp)
- if not tol:
- tol = num_tokens_from_string(delta.content)
-
- finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
- if finish_reason == "length":
- if is_chinese(ans):
- ans += LENGTH_NOTIFICATION_CN
- else:
- ans += LENGTH_NOTIFICATION_EN
-
- yield ans, tol
-
async def async_chat(self, system, history, gen_conf, **kwargs):
hist = list(history) if history else []
if system:
@@ -1795,22 +1392,7 @@ class LiteLLMBase(ABC):
def _should_retry(self, error_code: str) -> bool:
return error_code in self._retryable_errors
- def _exceptions(self, e, attempt) -> str | None:
- logging.exception("OpenAI chat_with_tools")
- # Classify the error
- error_code = self._classify_error(e)
- if attempt == self.max_retries:
- error_code = LLMErrorCode.ERROR_MAX_RETRIES
-
- if self._should_retry(error_code):
- delay = self._get_delay()
- logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
- time.sleep(delay)
- return None
-
- return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
-
- async def _exceptions_async(self, e, attempt) -> str | None:
+ async def _exceptions_async(self, e, attempt):
logging.exception("LiteLLMBase async completion")
error_code = self._classify_error(e)
if attempt == self.max_retries:
@@ -1859,71 +1441,7 @@ class LiteLLMBase(ABC):
self.toolcall_session = toolcall_session
self.tools = tools
- def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs):
- completion_args = {
- "model": self.model_name,
- "messages": history,
- "api_key": self.api_key,
- "num_retries": self.max_retries,
- **kwargs,
- }
- if stream:
- completion_args.update(
- {
- "stream": stream,
- }
- )
- if tools and self.tools:
- completion_args.update(
- {
- "tools": self.tools,
- "tool_choice": "auto",
- }
- )
- if self.provider in FACTORY_DEFAULT_BASE_URL:
- completion_args.update({"api_base": self.base_url})
- elif self.provider == SupportedLiteLLMProvider.Bedrock:
- completion_args.pop("api_key", None)
- completion_args.pop("api_base", None)
- completion_args.update(
- {
- "aws_access_key_id": self.bedrock_ak,
- "aws_secret_access_key": self.bedrock_sk,
- "aws_region_name": self.bedrock_region,
- }
- )
-
- if self.provider == SupportedLiteLLMProvider.OpenRouter:
- if self.provider_order:
-
- def _to_order_list(x):
- if x is None:
- return []
- if isinstance(x, str):
- return [s.strip() for s in x.split(",") if s.strip()]
- if isinstance(x, (list, tuple)):
- return [str(s).strip() for s in x if str(s).strip()]
- return []
-
- extra_body = {}
- provider_cfg = {}
- provider_order = _to_order_list(self.provider_order)
- provider_cfg["order"] = provider_order
- provider_cfg["allow_fallbacks"] = False
- extra_body["provider"] = provider_cfg
- completion_args.update({"extra_body": extra_body})
-
- # Ollama deployments commonly sit behind a reverse proxy that enforces
- # Bearer auth. Ensure the Authorization header is set when an API key
- # is provided, while respecting any user-supplied headers. #11350
- extra_headers = deepcopy(completion_args.get("extra_headers") or {})
- if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers:
- extra_headers["Authorization"] = f"Bearer {self.api_key}"
- if extra_headers:
- completion_args["extra_headers"] = extra_headers
- return completion_args
-
- def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
+ async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
@@ -1931,16 +1449,14 @@ class LiteLLMBase(ABC):
ans = ""
tk_count = 0
hist = deepcopy(history)
-
- # Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
- history = deepcopy(hist) # deepcopy is required here
+ history = deepcopy(hist)
try:
for _ in range(self.max_rounds + 1):
logging.info(f"{self.tools=}")
completion_args = self._construct_completion_args(history=history, stream=False, tools=True, **gen_conf)
- response = litellm.completion(
+ response = await litellm.acompletion(
**completion_args,
drop_params=True,
timeout=self.timeout,
@@ -1966,7 +1482,7 @@ class LiteLLMBase(ABC):
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
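+ # Offload the blocking tool-call session to a worker thread so the event loop stays responsive.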
- tool_response = self.toolcall_session.tool_call(name, args)
+ tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
history = self._append_history(history, tool_call, tool_response)
ans += self._verbose_tool_use(name, args, tool_response)
except Exception as e:
@@ -1977,49 +1493,19 @@ class LiteLLMBase(ABC):
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
- response, token_count = self._chat(history, gen_conf)
+ response, token_count = await self.async_chat("", history, gen_conf)
ans += response
tk_count += token_count
return ans, tk_count
except Exception as e:
- e = self._exceptions(e, attempt)
+ e = await self._exceptions_async(e, attempt)
if e:
return e, tk_count
assert False, "Shouldn't be here."
- def chat(self, system, history, gen_conf={}, **kwargs):
- if system and history and history[0].get("role") != "system":
- history.insert(0, {"role": "system", "content": system})
- gen_conf = self._clean_conf(gen_conf)
-
- # Implement exponential backoff retry strategy
- for attempt in range(self.max_retries + 1):
- try:
- response = self._chat(history, gen_conf, **kwargs)
- return response
- except Exception as e:
- e = self._exceptions(e, attempt)
- if e:
- return e, 0
- assert False, "Shouldn't be here."
-
- def _wrap_toolcall_message(self, stream):
- final_tool_calls = {}
-
- for chunk in stream:
- for tool_call in chunk.choices[0].delta.tool_calls or []:
- index = tool_call.index
-
- if index not in final_tool_calls:
- final_tool_calls[index] = tool_call
-
- final_tool_calls[index].function.arguments += tool_call.function.arguments
-
- return final_tool_calls
-
- def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
+ async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
if system and history and history[0].get("role") != "system":
@@ -2028,16 +1514,15 @@ class LiteLLMBase(ABC):
total_tokens = 0
hist = deepcopy(history)
- # Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
- history = deepcopy(hist) # deepcopy is required here
+ history = deepcopy(hist)
try:
for _ in range(self.max_rounds + 1):
reasoning_start = False
logging.info(f"{tools=}")
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
- response = litellm.completion(
+ response = await litellm.acompletion(
**completion_args,
drop_params=True,
timeout=self.timeout,
@@ -2046,7 +1531,7 @@ class LiteLLMBase(ABC):
final_tool_calls = {}
answer = ""
- for resp in response:
+ async for resp in response:
if not hasattr(resp, "choices") or not resp.choices:
continue
@@ -2082,7 +1567,7 @@ class LiteLLMBase(ABC):
if not tol:
total_tokens += num_tokens_from_string(delta.content)
else:
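+ # Chunk-reported usage is treated as the stream's running total, so assign rather than accumulate.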
- total_tokens += tol
+ total_tokens = tol
finish_reason = getattr(resp.choices[0], "finish_reason", "")
if finish_reason == "length":
@@ -2097,31 +1582,25 @@ class LiteLLMBase(ABC):
try:
args = json_repair.loads(tool_call.function.arguments)
yield self._verbose_tool_use(name, args, "Begin to call...")
- tool_response = self.toolcall_session.tool_call(name, args)
+ tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
history = self._append_history(history, tool_call, tool_response)
yield self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
- history.append(
- {
- "role": "tool",
- "tool_call_id": tool_call.id,
- "content": f"Tool call error: \n{tool_call}\nException:\n{str(e)}",
- }
- )
+ history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
yield self._verbose_tool_use(name, {}, str(e))
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
- response = litellm.completion(
+ response = await litellm.acompletion(
**completion_args,
drop_params=True,
timeout=self.timeout,
)
- for resp in response:
+ async for resp in response:
if not hasattr(resp, "choices") or not resp.choices:
continue
delta = resp.choices[0].delta
@@ -2131,14 +1610,14 @@ class LiteLLMBase(ABC):
if not tol:
total_tokens += num_tokens_from_string(delta.content)
else:
- total_tokens += tol
+ total_tokens = tol
yield delta.content
yield total_tokens
return
except Exception as e:
- e = self._exceptions(e, attempt)
+ e = await self._exceptions_async(e, attempt)
if e:
yield e
yield total_tokens
@@ -2146,53 +1625,71 @@ class LiteLLMBase(ABC):
assert False, "Shouldn't be here."
- def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
- if system and history and history[0].get("role") != "system":
- history.insert(0, {"role": "system", "content": system})
- gen_conf = self._clean_conf(gen_conf)
- ans = ""
- total_tokens = 0
- try:
- for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
- yield delta_ans
- total_tokens += tol
- except openai.APIError as e:
- yield ans + "\n**ERROR**: " + str(e)
+ def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs):
+ completion_args = {
+ "model": self.model_name,
+ "messages": history,
+ "api_key": self.api_key,
+ "num_retries": self.max_retries,
+ **kwargs,
+ }
+ if stream:
+ completion_args.update(
+ {
+ "stream": stream,
+ }
+ )
+ if tools and self.tools:
+ completion_args.update(
+ {
+ "tools": self.tools,
+ "tool_choice": "auto",
+ }
+ )
+ if self.provider in FACTORY_DEFAULT_BASE_URL:
+ completion_args.update({"api_base": self.base_url})
+ elif self.provider == SupportedLiteLLMProvider.Bedrock:
+ completion_args.pop("api_key", None)
+ completion_args.pop("api_base", None)
+ completion_args.update(
+ {
+ "aws_access_key_id": self.bedrock_ak,
+ "aws_secret_access_key": self.bedrock_sk,
+ "aws_region_name": self.bedrock_region,
+ }
+ )
+ elif self.provider == SupportedLiteLLMProvider.OpenRouter:
+ if self.provider_order:
- yield total_tokens
+ def _to_order_list(x):
+ if x is None:
+ return []
+ if isinstance(x, str):
+ return [s.strip() for s in x.split(",") if s.strip()]
+ if isinstance(x, (list, tuple)):
+ return [str(s).strip() for s in x if str(s).strip()]
+ return []
- def _calculate_dynamic_ctx(self, history):
- """Calculate dynamic context window size"""
+ extra_body = {}
+ provider_cfg = {}
+ provider_order = _to_order_list(self.provider_order)
+ provider_cfg["order"] = provider_order
+ provider_cfg["allow_fallbacks"] = False
+ extra_body["provider"] = provider_cfg
+ completion_args.update({"extra_body": extra_body})
+ elif self.provider == SupportedLiteLLMProvider.GPUStack:
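+ # GPUStack is a locally deployed endpoint, so forward the user-configured base_url as api_base.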
+ completion_args.update(
+ {
+ "api_base": self.base_url,
+ }
+ )
- def count_tokens(text):
- """Calculate token count for text"""
- # Simple calculation: 1 token per ASCII character
- # 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
- total = 0
- for char in text:
- if ord(char) < 128: # ASCII characters
- total += 1
- else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
- total += 2
- return total
-
- # Calculate total tokens for all messages
- total_tokens = 0
- for message in history:
- content = message.get("content", "")
- # Calculate content tokens
- content_tokens = count_tokens(content)
- # Add role marker token overhead
- role_tokens = 4
- total_tokens += content_tokens + role_tokens
-
- # Apply 1.2x buffer ratio
- total_tokens_with_buffer = int(total_tokens * 1.2)
-
- if total_tokens_with_buffer <= 8192:
- ctx_size = 8192
- else:
- ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
- ctx_size = ctx_multiplier * 8192
-
- return ctx_size
+ # Ollama deployments commonly sit behind a reverse proxy that enforces
+ # Bearer auth. Ensure the Authorization header is set when an API key
+ # is provided, while respecting any user-supplied headers. #11350
+ extra_headers = deepcopy(completion_args.get("extra_headers") or {})
+ if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers:
+ extra_headers["Authorization"] = f"Bearer {self.api_key}"
+ if extra_headers:
+ completion_args["extra_headers"] = extra_headers
+ return completion_args