Refa: cleanup synchronous functions in chat_model and implement synchronization for conversation and dialog chats (#11779)

### What problem does this PR solve?

Clean up the synchronous functions in chat_model and implement asynchronous
chat handling for the conversation and dialog services.

### Type of change

- [x] Refactoring
- [x] Performance Improvement
Author: Yongteng Lei
Date: 2025-12-08 09:43:03 +08:00
Committed by: GitHub
Parent: 9b8971a9de
Commit: 51ec708c58

10 changed files with 421 additions and 843 deletions
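The same caller-side pattern repeats throughout the diffs below: the blocking generators (`chat`, `ask`, `completion`, `iframe_completion`) become async generators (`async_chat`, `async_ask`, `async_completion`, `async_iframe_completion`), and each consumer switches from `for` to `async for` inside an `async def`. A minimal, self-contained sketch of that migration; the `async_chat` stub and the argument values here are placeholders standing in for the real service functions:

```python
import asyncio
import json


# Stand-in for the new async_chat() in api.db.services.dialog_service;
# it exists here only so the example runs on its own.
async def async_chat(dialog, messages, stream=True, **kwargs):
    for chunk in ("Hel", "Hello", "Hello!"):
        yield {"answer": chunk, "reference": {}}


# Caller-side migration: `def stream()` + `for ans in chat(...)` becomes
# `async def stream()` + `async for ans in async_chat(...)`.
async def stream(dialog, messages, **kwargs):
    async for ans in async_chat(dialog, messages, True, **kwargs):
        yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"


async def main():
    async for line in stream("demo-dialog", [{"role": "user", "content": "Hi"}]):
        print(line, end="")


if __name__ == "__main__":
    asyncio.run(main())
```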

View File

@@ -23,7 +23,7 @@ from quart import Response, request
 from api.apps import current_user, login_required
 from api.db.db_models import APIToken
 from api.db.services.conversation_service import ConversationService, structure_answer
-from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap
+from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap
 from api.db.services.llm_service import LLMBundle
 from api.db.services.search_service import SearchService
 from api.db.services.tenant_llm_service import TenantLLMService
@@ -218,10 +218,10 @@ async def completion():
     dia.llm_setting = chat_model_config
     is_embedded = bool(chat_model_id)

-    def stream():
+    async def stream():
         nonlocal dia, msg, req, conv
         try:
-            for ans in chat(dia, msg, True, **req):
+            async for ans in async_chat(dia, msg, True, **req):
                 ans = structure_answer(conv, ans, message_id, conv.id)
                 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
                 if not is_embedded:
@@ -241,7 +241,7 @@ async def completion():
     else:
         answer = None
-        for ans in chat(dia, msg, **req):
+        async for ans in async_chat(dia, msg, **req):
             answer = structure_answer(conv, ans, message_id, conv.id)
             if not is_embedded:
                 ConversationService.update_by_id(conv.id, conv.to_dict())
@@ -406,10 +406,10 @@ async def ask_about():
     if search_app:
         search_config = search_app.get("search_config", {})

-    def stream():
+    async def stream():
         nonlocal req, uid
         try:
-            for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
+            async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
                 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
         except Exception as e:
             yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"

View File

@@ -34,8 +34,9 @@ async def set_api_key():
     if not all([secret_key, public_key, host]):
         return get_error_data_result(message="Missing required fields")

+    current_user_id = current_user.id
     langfuse_keys = dict(
-        tenant_id=current_user.id,
+        tenant_id=current_user_id,
         secret_key=secret_key,
         public_key=public_key,
         host=host,
@@ -45,23 +46,24 @@ async def set_api_key():
     if not langfuse.auth_check():
         return get_error_data_result(message="Invalid Langfuse keys")

-    langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id)
+    langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id)
     with DB.atomic():
         try:
             if not langfuse_entry:
                 TenantLangfuseService.save(**langfuse_keys)
             else:
-                TenantLangfuseService.update_by_tenant(tenant_id=current_user.id, langfuse_keys=langfuse_keys)
+                TenantLangfuseService.update_by_tenant(tenant_id=current_user_id, langfuse_keys=langfuse_keys)
             return get_json_result(data=langfuse_keys)
         except Exception as e:
-            server_error_response(e)
+            return server_error_response(e)


 @manager.route("/api_key", methods=["GET"])  # noqa: F821
 @login_required
 @validate_request()
 def get_api_key():
-    langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user.id)
+    current_user_id = current_user.id
+    langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user_id)
     if not langfuse_entry:
         return get_json_result(message="Have not record any Langfuse keys.")
@@ -72,7 +74,7 @@ def get_api_key():
     except langfuse.api.core.api_error.ApiError as api_err:
         return get_json_result(message=f"Error from Langfuse: {api_err}")
     except Exception as e:
-        server_error_response(e)
+        return server_error_response(e)

     langfuse_entry["project_id"] = langfuse.api.projects.get().dict()["data"][0]["id"]
     langfuse_entry["project_name"] = langfuse.api.projects.get().dict()["data"][0]["name"]
@@ -84,7 +86,8 @@ def get_api_key():
 @login_required
 @validate_request()
 def delete_api_key():
-    langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id)
+    current_user_id = current_user.id
+    langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id)
     if not langfuse_entry:
         return get_json_result(message="Have not record any Langfuse keys.")
@@ -93,4 +96,4 @@ def delete_api_key():
         TenantLangfuseService.delete_model(langfuse_entry)
         return get_json_result(data=True)
     except Exception as e:
-        server_error_response(e)
+        return server_error_response(e)

View File

@@ -74,7 +74,7 @@ async def set_api_key():
         assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
         mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
         try:
-            m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
+            m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
             if m.find("**ERROR**") >= 0:
                 raise Exception(m)
             chat_passed = True
@@ -217,7 +217,7 @@ async def add_llm():
            **extra,
        )
        try:
-           m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
+           m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
           if not tc and m.find("**ERROR**:") >= 0:
               raise Exception(m)
        except Exception as e:

View File

@@ -26,9 +26,10 @@ from api.db.db_models import APIToken
 from api.db.services.api_service import API4ConversationService
 from api.db.services.canvas_service import UserCanvasService, completion_openai
 from api.db.services.canvas_service import completion as agent_completion
-from api.db.services.conversation_service import ConversationService, iframe_completion
-from api.db.services.conversation_service import completion as rag_completion
-from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap, meta_filter
+from api.db.services.conversation_service import ConversationService
+from api.db.services.conversation_service import async_iframe_completion as iframe_completion
+from api.db.services.conversation_service import async_completion as rag_completion
+from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap, meta_filter
 from api.db.services.document_service import DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
@@ -141,7 +142,7 @@ async def chat_completion(tenant_id, chat_id):
         return resp
     else:
         answer = None
-        for ans in rag_completion(tenant_id, chat_id, **req):
+        async for ans in rag_completion(tenant_id, chat_id, **req):
             answer = ans
             break
         return get_result(data=answer)
@@ -245,7 +246,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
     # The value for the usage field on all chunks except for the last one will be null.
     # The usage field on the last chunk contains token usage statistics for the entire request.
     # The choices field on the last chunk will always be an empty array [].
-    def streamed_response_generator(chat_id, dia, msg):
+    async def streamed_response_generator(chat_id, dia, msg):
         token_used = 0
         answer_cache = ""
         reasoning_cache = ""
@@ -274,7 +275,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
         }
         try:
-            for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
+            async for ans in async_chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
                 last_ans = ans
                 answer = ans["answer"]
@@ -342,7 +343,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
         return resp
     else:
         answer = None
-        for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
+        async for ans in async_chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
             # focus answer content only
             answer = ans
             break
@@ -733,10 +734,10 @@ async def ask_about(tenant_id):
             return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
     uid = tenant_id

-    def stream():
+    async def stream():
         nonlocal req, uid
         try:
-            for ans in ask(req["question"], req["kb_ids"], uid):
+            async for ans in async_ask(req["question"], req["kb_ids"], uid):
                 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
         except Exception as e:
             yield "data:" + json.dumps(
@@ -827,7 +828,7 @@ async def chatbot_completions(dialog_id):
         resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
         return resp

-    for answer in iframe_completion(dialog_id, **req):
+    async for answer in iframe_completion(dialog_id, **req):
         return get_result(data=answer)
@@ -918,10 +919,10 @@ async def ask_about_embedded():
     if search_app := SearchService.get_detail(search_id):
         search_config = search_app.get("search_config", {})

-    def stream():
+    async def stream():
         nonlocal req, uid
         try:
-            for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
+            async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
                 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
         except Exception as e:
             yield "data:" + json.dumps(

View File

@@ -19,7 +19,7 @@ from common.constants import StatusEnum
 from api.db.db_models import Conversation, DB
 from api.db.services.api_service import API4ConversationService
 from api.db.services.common_service import CommonService
-from api.db.services.dialog_service import DialogService, chat
+from api.db.services.dialog_service import DialogService, async_chat
 from common.misc_utils import get_uuid

 import json
@@ -89,8 +89,7 @@ def structure_answer(conv, ans, message_id, session_id):
     conv.reference[-1] = reference
     return ans

-def completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
+async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
     assert name, "`name` can not be empty."
     dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
     assert dia, "You do not own the chat."
@@ -148,7 +147,7 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
     if stream:
         try:
-            for ans in chat(dia, msg, True, **kwargs):
+            async for ans in async_chat(dia, msg, True, **kwargs):
                 ans = structure_answer(conv, ans, message_id, session_id)
                 yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
             ConversationService.update_by_id(conv.id, conv.to_dict())
@@ -160,14 +159,13 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
     else:
         answer = None
-        for ans in chat(dia, msg, False, **kwargs):
+        async for ans in async_chat(dia, msg, False, **kwargs):
             answer = structure_answer(conv, ans, message_id, session_id)
             ConversationService.update_by_id(conv.id, conv.to_dict())
             break
         yield answer

-def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
+async def async_iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
     e, dia = DialogService.get_by_id(dialog_id)
     assert e, "Dialog not found"
     if not session_id:
@@ -222,7 +220,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg
     if stream:
         try:
-            for ans in chat(dia, msg, True, **kwargs):
+            async for ans in async_chat(dia, msg, True, **kwargs):
                 ans = structure_answer(conv, ans, message_id, session_id)
                 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
                                            ensure_ascii=False) + "\n\n"
@@ -235,7 +233,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg
     else:
         answer = None
-        for ans in chat(dia, msg, False, **kwargs):
+        async for ans in async_chat(dia, msg, False, **kwargs):
             answer = structure_answer(conv, ans, message_id, session_id)
             API4ConversationService.append_message(conv.id, conv.to_dict())
             break

View File

@@ -178,7 +178,8 @@ class DialogService(CommonService):
             offset += limit
         return res

-def chat_solo(dialog, messages, stream=True):
+async def async_chat_solo(dialog, messages, stream=True):
     attachments = ""
     if "files" in messages[-1]:
         attachments = "\n\n".join(FileService.get_files(messages[-1]["files"]))
@@ -197,7 +198,8 @@ def chat_solo(dialog, messages, stream=True):
     if stream:
         last_ans = ""
         delta_ans = ""
-        for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
+        answer = ""
+        async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
             answer = ans
             delta_ans = ans[len(last_ans):]
             if num_tokens_from_string(delta_ans) < 16:
@@ -208,7 +210,7 @@ def chat_solo(dialog, messages, stream=True):
         if delta_ans:
             yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
     else:
-        answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
+        answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
         user_content = msg[-1].get("content", "[content not available]")
         logging.debug("User: {}|Assistant: {}".format(user_content, answer))
         yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
@@ -347,13 +349,12 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
         return []
     return list(doc_ids)

-def chat(dialog, messages, stream=True, **kwargs):
+async def async_chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
-        for ans in chat_solo(dialog, messages, stream):
+        async for ans in async_chat_solo(dialog, messages, stream):
             yield ans
-        return None
+        return

     chat_start_ts = timer()
@@ -400,7 +401,7 @@ def chat(dialog, messages, stream=True, **kwargs):
         ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
         if ans:
             yield ans
-            return None
+            return

     for p in prompt_config["parameters"]:
         if p["key"] == "knowledge":
@@ -508,7 +509,8 @@ def chat(dialog, messages, stream=True, **kwargs):
             empty_res = prompt_config["empty_response"]
             yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
                    "audio_binary": tts(tts_mdl, empty_res)}
-            return {"answer": prompt_config["empty_response"], "reference": kbinfos}
+            yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
+            return

     kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
     gen_conf = dialog.llm_setting
@@ -612,7 +614,7 @@ def chat(dialog, messages, stream=True, **kwargs):
     if stream:
         last_ans = ""
         answer = ""
-        for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
+        async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
             if thought:
                 ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
             answer = ans
@@ -626,14 +628,14 @@ def chat(dialog, messages, stream=True, **kwargs):
             yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
         yield decorate_answer(thought + answer)
     else:
-        answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
+        answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
         user_content = msg[-1].get("content", "[content not available]")
         logging.debug("User: {}|Assistant: {}".format(user_content, answer))
         res = decorate_answer(answer)
         res["audio_binary"] = tts(tts_mdl, answer)
         yield res
-    return None
+    return


 def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
@@ -805,8 +807,7 @@ def tts(tts_mdl, text):
         return None
     return binascii.hexlify(bin).decode("utf-8")

-def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
+async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
     doc_ids = search_config.get("doc_ids", [])
     rerank_mdl = None
     kb_ids = search_config.get("kb_ids", kb_ids)
@@ -880,7 +881,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
         return {"answer": answer, "reference": refs}

     answer = ""
-    for ans in chat_mdl.chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
+    async for ans in chat_mdl.async_chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
         answer = ans
         yield {"answer": answer, "reference": {}}
     yield decorate_answer(answer)

View File

@@ -25,14 +25,17 @@ Provides functionality for evaluating RAG system performance including:
 - Configuration recommendations
 """

+import asyncio
 import logging
+import queue
+import threading
 from typing import List, Dict, Any, Optional, Tuple
 from datetime import datetime
 from timeit import default_timer as timer

 from api.db.db_models import EvaluationDataset, EvaluationCase, EvaluationRun, EvaluationResult
 from api.db.services.common_service import CommonService
-from api.db.services.dialog_service import DialogService, chat
+from api.db.services.dialog_service import DialogService
 from common.misc_utils import get_uuid
 from common.time_utils import current_timestamp
 from common.constants import StatusEnum
@@ -357,6 +360,42 @@ class EvaluationService(CommonService):
             answer = ""
             retrieved_chunks = []

+            def _sync_from_async_gen(async_gen):
+                result_queue: queue.Queue = queue.Queue()
+
+                def runner():
+                    loop = asyncio.new_event_loop()
+                    asyncio.set_event_loop(loop)
+
+                    async def consume():
+                        try:
+                            async for item in async_gen:
+                                result_queue.put(item)
+                        except Exception as e:
+                            result_queue.put(e)
+                        finally:
+                            result_queue.put(StopIteration)
+
+                    loop.run_until_complete(consume())
+                    loop.close()
+
+                threading.Thread(target=runner, daemon=True).start()
+
+                while True:
+                    item = result_queue.get()
+                    if item is StopIteration:
+                        break
+                    if isinstance(item, Exception):
+                        raise item
+                    yield item
+
+            def chat(dialog, messages, stream=True, **kwargs):
+                from api.db.services.dialog_service import async_chat
+
+                return _sync_from_async_gen(async_chat(dialog, messages, stream=stream, **kwargs))
+
             for ans in chat(dialog, messages, stream=False):
                 if isinstance(ans, dict):
                     answer = ans.get("answer", "")

View File

@@ -16,15 +16,17 @@
 import asyncio
 import inspect
 import logging
+import queue
 import re
 import threading
-from common.token_utils import num_tokens_from_string
 from functools import partial
 from typing import Generator

-from common.constants import LLMType
 from api.db.db_models import LLM
 from api.db.services.common_service import CommonService
 from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
+from common.constants import LLMType
+from common.token_utils import num_tokens_from_string


 class LLMService(CommonService):
@@ -33,6 +35,7 @@ class LLMService(CommonService):
     def get_init_tenant_llm(user_id):
         from common import settings

+
         tenant_llm = []
         model_configs = {
@@ -193,7 +196,7 @@ class LLMBundle(LLM4Tenant):
             generation = self.langfuse.start_generation(
                 trace_context=self.trace_context,
                 name="stream_transcription",
-                metadata={"model": self.llm_name}
+                metadata={"model": self.llm_name},
             )
         final_text = ""
         used_tokens = 0
@@ -217,32 +220,34 @@ class LLMBundle(LLM4Tenant):
             if self.langfuse:
                 generation.update(
                     output={"output": final_text},
-                    usage_details={"total_tokens": used_tokens}
+                    usage_details={"total_tokens": used_tokens},
                 )
                 generation.end()
             return

         if self.langfuse:
-            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="stream_transcription", metadata={"model": self.llm_name})
-        full_text, used_tokens = mdl.transcription(audio)
-        if not TenantLLMService.increase_usage(
-            self.tenant_id, self.llm_type, used_tokens
-        ):
-            logging.error(
-                f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}"
-            )
+            generation = self.langfuse.start_generation(
+                trace_context=self.trace_context,
+                name="stream_transcription",
+                metadata={"model": self.llm_name},
+            )
+
+        full_text, used_tokens = mdl.transcription(audio)
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            logging.error(f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}")

         if self.langfuse:
             generation.update(
                 output={"output": full_text},
-                usage_details={"total_tokens": used_tokens}
+                usage_details={"total_tokens": used_tokens},
             )
             generation.end()

         yield {
             "event": "final",
             "text": full_text,
-            "streaming": False
+            "streaming": False,
         }
def tts(self, text: str) -> Generator[bytes, None, None]: def tts(self, text: str) -> Generator[bytes, None, None]:
@@ -289,61 +294,79 @@ class LLMBundle(LLM4Tenant):
             return kwargs
         else:
             return {k: v for k, v in kwargs.items() if k in allowed_params}

+    def _run_coroutine_sync(self, coro):
+        try:
+            asyncio.get_running_loop()
+        except RuntimeError:
+            return asyncio.run(coro)
+
+        result_queue: queue.Queue = queue.Queue()
+
+        def runner():
+            try:
+                result_queue.put((True, asyncio.run(coro)))
+            except Exception as e:
+                result_queue.put((False, e))
+
+        thread = threading.Thread(target=runner, daemon=True)
+        thread.start()
+        thread.join()
+
+        success, value = result_queue.get_nowait()
+        if success:
+            return value
+        raise value
+
     def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
-        if self.langfuse:
-            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
-
-        chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
-        if self.is_tools and self.mdl.is_tools:
-            chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
-        use_kwargs = self._clean_param(chat_partial, **kwargs)
-        txt, used_tokens = chat_partial(**use_kwargs)
-        txt = self._remove_reasoning_content(txt)
-
-        if not self.verbose_tool_use:
-            txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
-
-        if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
-            logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
-
-        if self.langfuse:
-            generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
-            generation.end()
-
-        return txt
+        return self._run_coroutine_sync(self.async_chat(system, history, gen_conf, **kwargs))
+
+    def _sync_from_async_stream(self, async_gen_fn, *args, **kwargs):
+        result_queue: queue.Queue = queue.Queue()
+
+        def runner():
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
+            async def consume():
+                try:
+                    async for item in async_gen_fn(*args, **kwargs):
+                        result_queue.put(item)
+                except Exception as e:
+                    result_queue.put(e)
+                finally:
+                    result_queue.put(StopIteration)
+
+            loop.run_until_complete(consume())
+            loop.close()
+
+        threading.Thread(target=runner, daemon=True).start()
+
+        while True:
+            item = result_queue.get()
+            if item is StopIteration:
+                break
+            if isinstance(item, Exception):
+                raise item
+            yield item

     def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
-        if self.langfuse:
-            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
-
         ans = ""
-        chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf)
-        total_tokens = 0
-        if self.is_tools and self.mdl.is_tools:
-            chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf)
-        use_kwargs = self._clean_param(chat_partial, **kwargs)
-
-        for txt in chat_partial(**use_kwargs):
+        for txt in self._sync_from_async_stream(self.async_chat_streamly, system, history, gen_conf, **kwargs):
             if isinstance(txt, int):
-                total_tokens = txt
-                if self.langfuse:
-                    generation.update(output={"output": ans})
-                    generation.end()
                 break

             if txt.endswith("</think>"):
-                ans = ans[: -len("</think>")]
+                ans = txt[: -len("</think>")]
+                continue

             if not self.verbose_tool_use:
                 txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)

-            ans += txt
+            # cancatination has beend done in async_chat_streamly
+            ans = txt
             yield ans

-        if total_tokens > 0:
-            if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
-                logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))

     def _bridge_sync_stream(self, gen):
         loop = asyncio.get_running_loop()
         queue: asyncio.Queue = asyncio.Queue()
@@ -352,7 +375,7 @@ class LLMBundle(LLM4Tenant):
            try:
                for item in gen:
                    loop.call_soon_threadsafe(queue.put_nowait, item)
-            except Exception as e:  # pragma: no cover
+            except Exception as e:
                loop.call_soon_threadsafe(queue.put_nowait, e)
            finally:
                loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
@@ -361,18 +384,27 @@ class LLMBundle(LLM4Tenant):
         return queue

     async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
-        chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
-        if self.is_tools and self.mdl.is_tools and hasattr(self.mdl, "chat_with_tools"):
-            chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
+        if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_with_tools"):
+            base_fn = self.mdl.async_chat_with_tools
+        elif hasattr(self.mdl, "async_chat"):
+            base_fn = self.mdl.async_chat
+        else:
+            raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
+
+        generation = None
+        if self.langfuse:
+            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
+
+        chat_partial = partial(base_fn, system, history, gen_conf)
         use_kwargs = self._clean_param(chat_partial, **kwargs)
-        if hasattr(self.mdl, "async_chat_with_tools") and self.is_tools and self.mdl.is_tools:
-            txt, used_tokens = await self.mdl.async_chat_with_tools(system, history, gen_conf, **use_kwargs)
-        elif hasattr(self.mdl, "async_chat"):
-            txt, used_tokens = await self.mdl.async_chat(system, history, gen_conf, **use_kwargs)
-        else:
-            txt, used_tokens = await asyncio.to_thread(chat_partial, **use_kwargs)
+        try:
+            txt, used_tokens = await chat_partial(**use_kwargs)
+        except Exception as e:
+            if generation:
+                generation.update(output={"error": str(e)})
+                generation.end()
+            raise

         txt = self._remove_reasoning_content(txt)
         if not self.verbose_tool_use:
@@ -381,19 +413,30 @@ class LLMBundle(LLM4Tenant):
         if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
             logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))

+        if generation:
+            generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
+            generation.end()
+
         return txt

     async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
         total_tokens = 0
         ans = ""
-        if self.is_tools and self.mdl.is_tools:
+        if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"):
             stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
-        else:
+        elif hasattr(self.mdl, "async_chat_streamly"):
             stream_fn = getattr(self.mdl, "async_chat_streamly", None)
+        else:
+            raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
+
+        generation = None
+        if self.langfuse:
+            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})

         if stream_fn:
             chat_partial = partial(stream_fn, system, history, gen_conf)
             use_kwargs = self._clean_param(chat_partial, **kwargs)
+            try:
                 async for txt in chat_partial(**use_kwargs):
                     if isinstance(txt, int):
                         total_tokens = txt
@@ -407,23 +450,14 @@ class LLMBundle(LLM4Tenant):
                     ans += txt
                     yield ans
+            except Exception as e:
+                if generation:
+                    generation.update(output={"error": str(e)})
+                    generation.end()
+                raise

             if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
                 logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
+
+            if generation:
+                generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
+                generation.end()
             return
-
-        chat_partial = partial(self.mdl.chat_streamly_with_tools if (self.is_tools and self.mdl.is_tools) else self.mdl.chat_streamly, system, history, gen_conf)
-        use_kwargs = self._clean_param(chat_partial, **kwargs)
-        queue = self._bridge_sync_stream(chat_partial(**use_kwargs))
-        while True:
-            item = await queue.get()
-            if item is StopAsyncIteration:
-                break
-            if isinstance(item, Exception):
-                raise item
-            if isinstance(item, int):
-                total_tokens = item
-                break
-            yield item
-
-        if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
-            logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))

View File

@@ -52,6 +52,8 @@ class SupportedLiteLLMProvider(StrEnum):
     JiekouAI = "Jiekou.AI"
     ZHIPU_AI = "ZHIPU-AI"
     MiniMax = "MiniMax"
+    DeerAPI = "DeerAPI"
+    GPUStack = "GPUStack"


 FACTORY_DEFAULT_BASE_URL = {
@@ -75,6 +77,7 @@ FACTORY_DEFAULT_BASE_URL = {
     SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai",
     SupportedLiteLLMProvider.ZHIPU_AI: "https://open.bigmodel.cn/api/paas/v4",
     SupportedLiteLLMProvider.MiniMax: "https://api.minimaxi.com/v1",
+    SupportedLiteLLMProvider.DeerAPI: "https://api.deerapi.com/v1",
 }
@@ -108,6 +111,8 @@ LITELLM_PROVIDER_PREFIX = {
     SupportedLiteLLMProvider.JiekouAI: "openai/",
     SupportedLiteLLMProvider.ZHIPU_AI: "openai/",
     SupportedLiteLLMProvider.MiniMax: "openai/",
+    SupportedLiteLLMProvider.DeerAPI: "openai/",
+    SupportedLiteLLMProvider.GPUStack: "openai/",
 }

 ChatModel = globals().get("ChatModel", {})

View File

@@ -19,7 +19,6 @@ import logging
 import os
 import random
 import re
-import threading
 import time
 from abc import ABC
 from copy import deepcopy
@@ -78,11 +77,9 @@ class Base(ABC):
         self.toolcall_sessions = {}

     def _get_delay(self):
-        """Calculate retry delay time"""
         return self.base_delay * random.uniform(10, 150)

     def _classify_error(self, error):
-        """Classify error based on error message content"""
         error_str = str(error).lower()

         keywords_mapping = [
@@ -139,89 +136,7 @@ class Base(ABC):
         return gen_conf

-    def _bridge_sync_stream(self, gen):
+    async def _async_chat_streamly(self, history, gen_conf, **kwargs):
"""Run a sync generator in a thread and yield asynchronously."""
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
def worker():
try:
for item in gen:
loop.call_soon_threadsafe(queue.put_nowait, item)
except Exception as exc: # pragma: no cover - defensive
loop.call_soon_threadsafe(queue.put_nowait, exc)
finally:
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
threading.Thread(target=worker, daemon=True).start()
return queue
def _chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwq") >= 0:
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using _chat_streamly")
final_ans = ""
tol_token = 0
for delta, tol in self._chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
if delta.startswith("<think>") or delta.endswith("</think>"):
continue
final_ans += delta
tol_token = tol
if len(final_ans.strip()) == 0:
final_ans = "**ERROR**: Empty response from reasoning model"
return final_ans.strip(), tol_token
if self.model_name.lower().find("qwen3") >= 0:
kwargs["extra_body"] = {"enable_thinking": False}
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
return "", 0
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, total_token_count_from_response(response)
def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
reasoning_start = False
if kwargs.get("stop") or "stop" in gen_conf:
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
else:
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>"
else:
reasoning_start = False
ans = resp.choices[0].delta.content
tol = total_token_count_from_response(resp)
if not tol:
tol = num_tokens_from_string(resp.choices[0].delta.content)
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans, tol
async def _async_chat_stream(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
reasoning_start = False reasoning_start = False
@@ -265,13 +180,19 @@ class Base(ABC):
         gen_conf = self._clean_conf(gen_conf)
         ans = ""
         total_tokens = 0
-        try:
-            async for delta_ans, tol in self._async_chat_stream(history, gen_conf, **kwargs):
-                ans = delta_ans
-                total_tokens += tol
-                yield delta_ans
-        except openai.APIError as e:
-            yield ans + "\n**ERROR**: " + str(e)
+        for attempt in range(self.max_retries + 1):
+            try:
+                async for delta_ans, tol in self._async_chat_streamly(history, gen_conf, **kwargs):
+                    ans = delta_ans
+                    total_tokens += tol
+                    yield ans
+            except Exception as e:
+                e = await self._exceptions_async(e, attempt)
+                if e:
+                    yield e
+                    yield total_tokens
+                    return

         yield total_tokens
@@ -307,7 +228,7 @@ class Base(ABC):
             logging.error(f"sync base giving up: {msg}")
         return msg

-    async def _exceptions_async(self, e, attempt) -> str | None:
+    async def _exceptions_async(self, e, attempt):
         logging.exception("OpenAI async completion")
         error_code = self._classify_error(e)
         if attempt == self.max_retries:
@ -357,61 +278,6 @@ class Base(ABC):
self.toolcall_session = toolcall_session self.toolcall_session = toolcall_session
self.tools = tools self.tools = tools
def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
ans = ""
tk_count = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = hist
try:
for _ in range(self.max_rounds + 1):
logging.info(f"{self.tools=}")
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
tk_count += total_token_count_from_response(response)
if any([not response.choices, not response.choices[0].message]):
raise Exception(f"500 response structure error. Response: {response}")
if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"
ans += response.choices[0].message.content
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, tk_count
for tool_call in response.choices[0].message.tool_calls:
logging.info(f"Response {tool_call=}")
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_session.tool_call(name, args)
history = self._append_history(history, tool_call, tool_response)
ans += self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
ans += self._verbose_tool_use(name, {}, str(e))
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response, token_count = self._chat(history, gen_conf)
ans += response
tk_count += token_count
return ans, tk_count
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, tk_count
assert False, "Shouldn't be here."
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}): async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf) gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system": if system and history and history[0].get("role") != "system":
@ -466,140 +332,6 @@ class Base(ABC):
assert False, "Shouldn't be here." assert False, "Shouldn't be here."
def chat(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
try:
return self._chat(history, gen_conf, **kwargs)
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, 0
assert False, "Shouldn't be here."
def _wrap_toolcall_message(self, stream):
final_tool_calls = {}
for chunk in stream:
for tool_call in chunk.choices[0].delta.tool_calls or []:
index = tool_call.index
if index not in final_tool_calls:
final_tool_calls[index] = tool_call
final_tool_calls[index].function.arguments += tool_call.function.arguments
return final_tool_calls
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
total_tokens = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = hist
try:
for _ in range(self.max_rounds + 1):
reasoning_start = False
logging.info(f"{tools=}")
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
final_tool_calls = {}
answer = ""
for resp in response:
if resp.choices[0].delta.tool_calls:
for tool_call in resp.choices[0].delta.tool_calls or []:
index = tool_call.index
if index not in final_tool_calls:
if not tool_call.function.arguments:
tool_call.function.arguments = ""
final_tool_calls[index] = tool_call
else:
final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""
continue
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
raise Exception("500 response structure error.")
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>"
yield ans
else:
reasoning_start = False
answer += resp.choices[0].delta.content
yield resp.choices[0].delta.content
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
total_tokens = tol
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
if finish_reason == "length":
yield self._length_stop("")
if answer:
yield total_tokens
return
for tool_call in final_tool_calls.values():
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
yield self._verbose_tool_use(name, args, "Begin to call...")
tool_response = self.toolcall_session.tool_call(name, args)
history = self._append_history(history, tool_call, tool_response)
yield self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
yield self._verbose_tool_use(name, {}, str(e))
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
raise Exception("500 response structure error.")
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
continue
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
total_tokens = tol
answer += resp.choices[0].delta.content
yield resp.choices[0].delta.content
yield total_tokens
return
except Exception as e:
e = self._exceptions(e, attempt)
if e:
yield e
yield total_tokens
return
assert False, "Shouldn't be here."
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}): async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf) gen_conf = self._clean_conf(gen_conf)
tools = self.tools tools = self.tools
@ -715,9 +447,10 @@ class Base(ABC):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2)) logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwq") >= 0: if self.model_name.lower().find("qwq") >= 0:
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly") logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly")
final_ans = "" final_ans = ""
tol_token = 0 tol_token = 0
-            async for delta, tol in self._async_chat_stream(history, gen_conf, with_reasoning=False, **kwargs):
+            async for delta, tol in self._async_chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
if delta.startswith("<think>") or delta.endswith("</think>"): if delta.startswith("<think>") or delta.endswith("</think>"):
continue continue
final_ans += delta final_ans += delta
@ -754,57 +487,6 @@ class Base(ABC):
return e, 0 return e, 0
assert False, "Shouldn't be here." assert False, "Shouldn't be here."
def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
try:
for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
yield delta_ans
total_tokens += tol
except openai.APIError as e:
yield ans + "\n**ERROR**: " + str(e)
yield total_tokens
def _calculate_dynamic_ctx(self, history):
"""Calculate dynamic context window size"""
def count_tokens(text):
"""Calculate token count for text"""
# Simple calculation: 1 token per ASCII character
# 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
total = 0
for char in text:
if ord(char) < 128: # ASCII characters
total += 1
else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
total += 2
return total
# Calculate total tokens for all messages
total_tokens = 0
for message in history:
content = message.get("content", "")
# Calculate content tokens
content_tokens = count_tokens(content)
# Add role marker token overhead
role_tokens = 4
total_tokens += content_tokens + role_tokens
# Apply 1.2x buffer ratio
total_tokens_with_buffer = int(total_tokens * 1.2)
if total_tokens_with_buffer <= 8192:
ctx_size = 8192
else:
ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
ctx_size = ctx_multiplier * 8192
return ctx_size
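For reference, a minimal standalone sketch of the context-window heuristic this removed helper implemented (free-standing here for illustration, outside any class): roughly 1 token per ASCII character, 2 per non-ASCII character, a 4-token per-message role overhead, a 1.2x buffer, rounded up to the next multiple of 8192.

```python
def estimate_ctx_size(history):
    """Standalone sketch of the removed _calculate_dynamic_ctx heuristic."""
    def count_tokens(text):
        # ~1 token per ASCII character, ~2 per non-ASCII (CJK etc.)
        return sum(1 if ord(ch) < 128 else 2 for ch in text)

    total = sum(count_tokens(m.get("content", "")) + 4 for m in history)  # +4 role overhead per message
    buffered = int(total * 1.2)  # 1.2x safety buffer
    if buffered <= 8192:
        return 8192
    return ((buffered // 8192) + 1) * 8192

# 1200 repetitions of "hello " = 7200 ASCII chars -> 7204 tokens -> 8644 buffered -> 16384
print(estimate_ctx_size([{"role": "user", "content": "hello " * 1200}]))
```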
class GptTurbo(Base): class GptTurbo(Base):
_FACTORY_NAME = "OpenAI" _FACTORY_NAME = "OpenAI"
@ -1504,16 +1186,6 @@ class GoogleChat(Base):
yield total_tokens yield total_tokens
class GPUStackChat(Base):
_FACTORY_NAME = "GPUStack"
def __init__(self, key=None, model_name="", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
super().__init__(key, model_name, base_url, **kwargs)
class TokenPonyChat(Base): class TokenPonyChat(Base):
_FACTORY_NAME = "TokenPony" _FACTORY_NAME = "TokenPony"
@ -1523,15 +1195,6 @@ class TokenPonyChat(Base):
super().__init__(key, model_name, base_url, **kwargs) super().__init__(key, model_name, base_url, **kwargs)
class DeerAPIChat(Base):
_FACTORY_NAME = "DeerAPI"
def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1", **kwargs):
if not base_url:
base_url = "https://api.deerapi.com/v1"
super().__init__(key, model_name, base_url, **kwargs)
class LiteLLMBase(ABC): class LiteLLMBase(ABC):
_FACTORY_NAME = [ _FACTORY_NAME = [
"Tongyi-Qianwen", "Tongyi-Qianwen",
@ -1562,6 +1225,8 @@ class LiteLLMBase(ABC):
"Jiekou.AI", "Jiekou.AI",
"ZHIPU-AI", "ZHIPU-AI",
"MiniMax", "MiniMax",
"DeerAPI",
"GPUStack",
] ]
def __init__(self, key, model_name, base_url=None, **kwargs): def __init__(self, key, model_name, base_url=None, **kwargs):
@ -1589,11 +1254,9 @@ class LiteLLMBase(ABC):
self.provider_order = json.loads(key).get("provider_order", "") self.provider_order = json.loads(key).get("provider_order", "")
def _get_delay(self): def _get_delay(self):
"""Calculate retry delay time"""
return self.base_delay * random.uniform(10, 150) return self.base_delay * random.uniform(10, 150)
def _classify_error(self, error): def _classify_error(self, error):
"""Classify error based on error message content"""
error_str = str(error).lower() error_str = str(error).lower()
keywords_mapping = [ keywords_mapping = [
@ -1619,72 +1282,6 @@ class LiteLLMBase(ABC):
del gen_conf["max_tokens"] del gen_conf["max_tokens"]
return gen_conf return gen_conf
def _chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwen3") >= 0:
kwargs["extra_body"] = {"enable_thinking": False}
completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf)
response = litellm.completion(
**completion_args,
drop_params=True,
timeout=self.timeout,
)
# response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
return "", 0
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, total_token_count_from_response(response)
def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
gen_conf = self._clean_conf(gen_conf)
reasoning_start = False
completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
stop = kwargs.get("stop")
if stop:
completion_args["stop"] = stop
response = litellm.completion(
**completion_args,
drop_params=True,
timeout=self.timeout,
)
for resp in response:
if not hasattr(resp, "choices") or not resp.choices:
continue
delta = resp.choices[0].delta
if not hasattr(delta, "content") or delta.content is None:
delta.content = ""
if kwargs.get("with_reasoning", True) and hasattr(delta, "reasoning_content") and delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += delta.reasoning_content + "</think>"
else:
reasoning_start = False
ans = delta.content
tol = total_token_count_from_response(resp)
if not tol:
tol = num_tokens_from_string(delta.content)
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
if finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans, tol
async def async_chat(self, system, history, gen_conf, **kwargs): async def async_chat(self, system, history, gen_conf, **kwargs):
hist = list(history) if history else [] hist = list(history) if history else []
if system: if system:
@ -1795,22 +1392,7 @@ class LiteLLMBase(ABC):
def _should_retry(self, error_code: str) -> bool: def _should_retry(self, error_code: str) -> bool:
return error_code in self._retryable_errors return error_code in self._retryable_errors
def _exceptions(self, e, attempt) -> str | None: async def _exceptions_async(self, e, attempt):
logging.exception("OpenAI chat_with_tools")
# Classify the error
error_code = self._classify_error(e)
if attempt == self.max_retries:
error_code = LLMErrorCode.ERROR_MAX_RETRIES
if self._should_retry(error_code):
delay = self._get_delay()
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
time.sleep(delay)
return None
return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
async def _exceptions_async(self, e, attempt) -> str | None:
logging.exception("LiteLLMBase async completion") logging.exception("LiteLLMBase async completion")
error_code = self._classify_error(e) error_code = self._classify_error(e)
if attempt == self.max_retries: if attempt == self.max_retries:
@ -1859,71 +1441,7 @@ class LiteLLMBase(ABC):
self.toolcall_session = toolcall_session self.toolcall_session = toolcall_session
self.tools = tools self.tools = tools
def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs): async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
completion_args = {
"model": self.model_name,
"messages": history,
"api_key": self.api_key,
"num_retries": self.max_retries,
**kwargs,
}
if stream:
completion_args.update(
{
"stream": stream,
}
)
if tools and self.tools:
completion_args.update(
{
"tools": self.tools,
"tool_choice": "auto",
}
)
if self.provider in FACTORY_DEFAULT_BASE_URL:
completion_args.update({"api_base": self.base_url})
elif self.provider == SupportedLiteLLMProvider.Bedrock:
completion_args.pop("api_key", None)
completion_args.pop("api_base", None)
completion_args.update(
{
"aws_access_key_id": self.bedrock_ak,
"aws_secret_access_key": self.bedrock_sk,
"aws_region_name": self.bedrock_region,
}
)
if self.provider == SupportedLiteLLMProvider.OpenRouter:
if self.provider_order:
def _to_order_list(x):
if x is None:
return []
if isinstance(x, str):
return [s.strip() for s in x.split(",") if s.strip()]
if isinstance(x, (list, tuple)):
return [str(s).strip() for s in x if str(s).strip()]
return []
extra_body = {}
provider_cfg = {}
provider_order = _to_order_list(self.provider_order)
provider_cfg["order"] = provider_order
provider_cfg["allow_fallbacks"] = False
extra_body["provider"] = provider_cfg
completion_args.update({"extra_body": extra_body})
# Ollama deployments commonly sit behind a reverse proxy that enforces
# Bearer auth. Ensure the Authorization header is set when an API key
# is provided, while respecting any user-supplied headers. #11350
extra_headers = deepcopy(completion_args.get("extra_headers") or {})
if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers:
extra_headers["Authorization"] = f"Bearer {self.api_key}"
if extra_headers:
completion_args["extra_headers"] = extra_headers
return completion_args
def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf) gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system": if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system}) history.insert(0, {"role": "system", "content": system})
@ -1931,16 +1449,14 @@ class LiteLLMBase(ABC):
ans = "" ans = ""
tk_count = 0 tk_count = 0
hist = deepcopy(history) hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1): for attempt in range(self.max_retries + 1):
history = deepcopy(hist) # deepcopy is required here history = deepcopy(hist)
try: try:
for _ in range(self.max_rounds + 1): for _ in range(self.max_rounds + 1):
logging.info(f"{self.tools=}") logging.info(f"{self.tools=}")
completion_args = self._construct_completion_args(history=history, stream=False, tools=True, **gen_conf) completion_args = self._construct_completion_args(history=history, stream=False, tools=True, **gen_conf)
response = litellm.completion( response = await litellm.acompletion(
**completion_args, **completion_args,
drop_params=True, drop_params=True,
timeout=self.timeout, timeout=self.timeout,
@ -1966,7 +1482,7 @@ class LiteLLMBase(ABC):
name = tool_call.function.name name = tool_call.function.name
try: try:
args = json_repair.loads(tool_call.function.arguments) args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_session.tool_call(name, args) tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
history = self._append_history(history, tool_call, tool_response) history = self._append_history(history, tool_call, tool_response)
ans += self._verbose_tool_use(name, args, tool_response) ans += self._verbose_tool_use(name, args, tool_response)
except Exception as e: except Exception as e:
@ -1977,49 +1493,19 @@ class LiteLLMBase(ABC):
logging.warning(f"Exceed max rounds: {self.max_rounds}") logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response, token_count = self._chat(history, gen_conf) response, token_count = await self.async_chat("", history, gen_conf)
ans += response ans += response
tk_count += token_count tk_count += token_count
return ans, tk_count return ans, tk_count
except Exception as e: except Exception as e:
e = self._exceptions(e, attempt) e = await self._exceptions_async(e, attempt)
if e: if e:
return e, tk_count return e, tk_count
assert False, "Shouldn't be here." assert False, "Shouldn't be here."
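A hedged usage sketch of the new non-streaming tool-calling path: the method is awaited once and returns the final answer plus the accumulated token count, while tool invocations run in a worker thread via `asyncio.to_thread` inside the loop. The `bind_tools` name is assumed from the `self.toolcall_session` / `self.tools` assignments shown above; the prompt and `gen_conf` values are illustrative only.

```python
async def ask_with_tools(llm, toolcall_session, tools):
    # bind_tools is assumed from the attribute assignments shown above;
    # the question and generation config are placeholders.
    llm.bind_tools(toolcall_session, tools)
    answer, token_count = await llm.async_chat_with_tools(
        "You are a helpful assistant.",
        [{"role": "user", "content": "What is the weather in Paris today?"}],
        {"temperature": 0.1},
    )
    return answer, token_count
```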
def chat(self, system, history, gen_conf={}, **kwargs): async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
try:
response = self._chat(history, gen_conf, **kwargs)
return response
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, 0
assert False, "Shouldn't be here."
def _wrap_toolcall_message(self, stream):
final_tool_calls = {}
for chunk in stream:
for tool_call in chunk.choices[0].delta.tool_calls or []:
index = tool_call.index
if index not in final_tool_calls:
final_tool_calls[index] = tool_call
final_tool_calls[index].function.arguments += tool_call.function.arguments
return final_tool_calls
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf) gen_conf = self._clean_conf(gen_conf)
tools = self.tools tools = self.tools
if system and history and history[0].get("role") != "system": if system and history and history[0].get("role") != "system":
@ -2028,16 +1514,15 @@ class LiteLLMBase(ABC):
total_tokens = 0 total_tokens = 0
hist = deepcopy(history) hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1): for attempt in range(self.max_retries + 1):
history = deepcopy(hist) # deepcopy is required here history = deepcopy(hist)
try: try:
for _ in range(self.max_rounds + 1): for _ in range(self.max_rounds + 1):
reasoning_start = False reasoning_start = False
logging.info(f"{tools=}") logging.info(f"{tools=}")
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf) completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
response = litellm.completion( response = await litellm.acompletion(
**completion_args, **completion_args,
drop_params=True, drop_params=True,
timeout=self.timeout, timeout=self.timeout,
@ -2046,7 +1531,7 @@ class LiteLLMBase(ABC):
final_tool_calls = {} final_tool_calls = {}
answer = "" answer = ""
for resp in response: async for resp in response:
if not hasattr(resp, "choices") or not resp.choices: if not hasattr(resp, "choices") or not resp.choices:
continue continue
@ -2082,7 +1567,7 @@ class LiteLLMBase(ABC):
if not tol: if not tol:
total_tokens += num_tokens_from_string(delta.content) total_tokens += num_tokens_from_string(delta.content)
else: else:
total_tokens += tol total_tokens = tol
finish_reason = getattr(resp.choices[0], "finish_reason", "") finish_reason = getattr(resp.choices[0], "finish_reason", "")
if finish_reason == "length": if finish_reason == "length":
@ -2097,31 +1582,25 @@ class LiteLLMBase(ABC):
try: try:
args = json_repair.loads(tool_call.function.arguments) args = json_repair.loads(tool_call.function.arguments)
yield self._verbose_tool_use(name, args, "Begin to call...") yield self._verbose_tool_use(name, args, "Begin to call...")
tool_response = self.toolcall_session.tool_call(name, args) tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
history = self._append_history(history, tool_call, tool_response) history = self._append_history(history, tool_call, tool_response)
yield self._verbose_tool_use(name, args, tool_response) yield self._verbose_tool_use(name, args, tool_response)
except Exception as e: except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append( history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": f"Tool call error: \n{tool_call}\nException:\n{str(e)}",
}
)
yield self._verbose_tool_use(name, {}, str(e)) yield self._verbose_tool_use(name, {}, str(e))
logging.warning(f"Exceed max rounds: {self.max_rounds}") logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf) completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
response = litellm.completion( response = await litellm.acompletion(
**completion_args, **completion_args,
drop_params=True, drop_params=True,
timeout=self.timeout, timeout=self.timeout,
) )
for resp in response: async for resp in response:
if not hasattr(resp, "choices") or not resp.choices: if not hasattr(resp, "choices") or not resp.choices:
continue continue
delta = resp.choices[0].delta delta = resp.choices[0].delta
@ -2131,14 +1610,14 @@ class LiteLLMBase(ABC):
if not tol: if not tol:
total_tokens += num_tokens_from_string(delta.content) total_tokens += num_tokens_from_string(delta.content)
else: else:
total_tokens += tol total_tokens = tol
yield delta.content yield delta.content
yield total_tokens yield total_tokens
return return
except Exception as e: except Exception as e:
e = self._exceptions(e, attempt) e = await self._exceptions_async(e, attempt)
if e: if e:
yield e yield e
yield total_tokens yield total_tokens
@ -2146,53 +1625,71 @@ class LiteLLMBase(ABC):
assert False, "Shouldn't be here." assert False, "Shouldn't be here."
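For the streaming variant, a minimal consumption sketch (names are illustrative): the async generator yields string deltas, including the verbose tool-call markers, and its final yield is the accumulated token count as an integer.

```python
async def collect_stream(llm, system, history, gen_conf):
    # String deltas (and verbose tool-call markers) arrive first; the
    # generator's last yielded item is the total token count (an int).
    answer, tokens = "", 0
    async for chunk in llm.async_chat_streamly_with_tools(system, history, gen_conf):
        if isinstance(chunk, int):
            tokens = chunk
        else:
            answer += chunk
    return answer, tokens
```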
(removed: old synchronous LiteLLMBase methods)

def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
    if system and history and history[0].get("role") != "system":
        history.insert(0, {"role": "system", "content": system})
    gen_conf = self._clean_conf(gen_conf)
    ans = ""
    total_tokens = 0
    try:
        for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
            yield delta_ans
            total_tokens += tol
    except openai.APIError as e:
        yield ans + "\n**ERROR**: " + str(e)

    yield total_tokens

def _calculate_dynamic_ctx(self, history):
    """Calculate dynamic context window size"""

    def count_tokens(text):
        """Calculate token count for text"""
        # Simple calculation: 1 token per ASCII character
        # 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
        total = 0
        for char in text:
            if ord(char) < 128:  # ASCII characters
                total += 1
            else:  # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
                total += 2
        return total

    # Calculate total tokens for all messages
    total_tokens = 0
    for message in history:
        content = message.get("content", "")
        # Calculate content tokens
        content_tokens = count_tokens(content)
        # Add role marker token overhead
        role_tokens = 4
        total_tokens += content_tokens + role_tokens

    # Apply 1.2x buffer ratio
    total_tokens_with_buffer = int(total_tokens * 1.2)

    if total_tokens_with_buffer <= 8192:
        ctx_size = 8192
    else:
        ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
        ctx_size = ctx_multiplier * 8192

    return ctx_size

(added: request-construction helper shared by the async paths)

def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs):
    completion_args = {
        "model": self.model_name,
        "messages": history,
        "api_key": self.api_key,
        "num_retries": self.max_retries,
        **kwargs,
    }
    if stream:
        completion_args.update(
            {
                "stream": stream,
            }
        )
    if tools and self.tools:
        completion_args.update(
            {
                "tools": self.tools,
                "tool_choice": "auto",
            }
        )
    if self.provider in FACTORY_DEFAULT_BASE_URL:
        completion_args.update({"api_base": self.base_url})
    elif self.provider == SupportedLiteLLMProvider.Bedrock:
        completion_args.pop("api_key", None)
        completion_args.pop("api_base", None)
        completion_args.update(
            {
                "aws_access_key_id": self.bedrock_ak,
                "aws_secret_access_key": self.bedrock_sk,
                "aws_region_name": self.bedrock_region,
            }
        )
    elif self.provider == SupportedLiteLLMProvider.OpenRouter:
        if self.provider_order:

            def _to_order_list(x):
                if x is None:
                    return []
                if isinstance(x, str):
                    return [s.strip() for s in x.split(",") if s.strip()]
                if isinstance(x, (list, tuple)):
                    return [str(s).strip() for s in x if str(s).strip()]
                return []

            extra_body = {}
            provider_cfg = {}
            provider_order = _to_order_list(self.provider_order)
            provider_cfg["order"] = provider_order
            provider_cfg["allow_fallbacks"] = False
            extra_body["provider"] = provider_cfg
            completion_args.update({"extra_body": extra_body})
    elif self.provider == SupportedLiteLLMProvider.GPUStack:
        completion_args.update(
            {
                "api_base": self.base_url,
            }
        )

    # Ollama deployments commonly sit behind a reverse proxy that enforces
    # Bearer auth. Ensure the Authorization header is set when an API key
    # is provided, while respecting any user-supplied headers. #11350
    extra_headers = deepcopy(completion_args.get("extra_headers") or {})
    if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers:
        extra_headers["Authorization"] = f"Bearer {self.api_key}"
    if extra_headers:
        completion_args["extra_headers"] = extra_headers

    return completion_args
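To illustrate the OpenRouter branch of the helper above: a comma-separated `provider_order` string is normalized into a list and packed into `extra_body` with fallbacks disabled. A minimal sketch of the same normalization, assuming an illustrative order of "openai, together":

```python
def to_order_list(x):
    # Mirrors the nested _to_order_list helper above.
    if x is None:
        return []
    if isinstance(x, str):
        return [s.strip() for s in x.split(",") if s.strip()]
    if isinstance(x, (list, tuple)):
        return [str(s).strip() for s in x if str(s).strip()]
    return []

extra_body = {"provider": {"order": to_order_list("openai, together"), "allow_fallbacks": False}}
# -> {"provider": {"order": ["openai", "together"], "allow_fallbacks": False}}
```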