Refa: cleanup synchronous functions in chat_model and implement synchronization for conversation and dialog chats (#11779)

### What problem does this PR solve?

Cleanup synchronous functions in chat_model and implement
synchronization for conversation and dialog chats.

### Type of change

- [x] Refactoring
- [x] Performance Improvement
This commit is contained in:
Yongteng Lei
2025-12-08 09:43:03 +08:00
committed by GitHub
parent 9b8971a9de
commit 51ec708c58
10 changed files with 421 additions and 843 deletions

View File

@ -23,7 +23,7 @@ from quart import Response, request
from api.apps import current_user, login_required
from api.db.db_models import APIToken
from api.db.services.conversation_service import ConversationService, structure_answer
from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap
from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap
from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
from api.db.services.tenant_llm_service import TenantLLMService
@ -218,10 +218,10 @@ async def completion():
dia.llm_setting = chat_model_config
is_embedded = bool(chat_model_id)
def stream():
async def stream():
nonlocal dia, msg, req, conv
try:
for ans in chat(dia, msg, True, **req):
async for ans in async_chat(dia, msg, True, **req):
ans = structure_answer(conv, ans, message_id, conv.id)
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
if not is_embedded:
@ -241,7 +241,7 @@ async def completion():
else:
answer = None
for ans in chat(dia, msg, **req):
async for ans in async_chat(dia, msg, **req):
answer = structure_answer(conv, ans, message_id, conv.id)
if not is_embedded:
ConversationService.update_by_id(conv.id, conv.to_dict())
@ -406,10 +406,10 @@ async def ask_about():
if search_app:
search_config = search_app.get("search_config", {})
def stream():
async def stream():
nonlocal req, uid
try:
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"

View File

@ -34,8 +34,9 @@ async def set_api_key():
if not all([secret_key, public_key, host]):
return get_error_data_result(message="Missing required fields")
current_user_id = current_user.id
langfuse_keys = dict(
tenant_id=current_user.id,
tenant_id=current_user_id,
secret_key=secret_key,
public_key=public_key,
host=host,
@ -45,23 +46,24 @@ async def set_api_key():
if not langfuse.auth_check():
return get_error_data_result(message="Invalid Langfuse keys")
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id)
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id)
with DB.atomic():
try:
if not langfuse_entry:
TenantLangfuseService.save(**langfuse_keys)
else:
TenantLangfuseService.update_by_tenant(tenant_id=current_user.id, langfuse_keys=langfuse_keys)
TenantLangfuseService.update_by_tenant(tenant_id=current_user_id, langfuse_keys=langfuse_keys)
return get_json_result(data=langfuse_keys)
except Exception as e:
server_error_response(e)
return server_error_response(e)
@manager.route("/api_key", methods=["GET"]) # noqa: F821
@login_required
@validate_request()
def get_api_key():
langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user.id)
current_user_id = current_user.id
langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user_id)
if not langfuse_entry:
return get_json_result(message="Have not record any Langfuse keys.")
@ -72,7 +74,7 @@ def get_api_key():
except langfuse.api.core.api_error.ApiError as api_err:
return get_json_result(message=f"Error from Langfuse: {api_err}")
except Exception as e:
server_error_response(e)
return server_error_response(e)
langfuse_entry["project_id"] = langfuse.api.projects.get().dict()["data"][0]["id"]
langfuse_entry["project_name"] = langfuse.api.projects.get().dict()["data"][0]["name"]
@ -84,7 +86,8 @@ def get_api_key():
@login_required
@validate_request()
def delete_api_key():
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id)
current_user_id = current_user.id
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id)
if not langfuse_entry:
return get_json_result(message="Have not record any Langfuse keys.")
@ -93,4 +96,4 @@ def delete_api_key():
TenantLangfuseService.delete_model(langfuse_entry)
return get_json_result(data=True)
except Exception as e:
server_error_response(e)
return server_error_response(e)

View File

@ -74,7 +74,7 @@ async def set_api_key():
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
try:
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
if m.find("**ERROR**") >= 0:
raise Exception(m)
chat_passed = True
@ -217,7 +217,7 @@ async def add_llm():
**extra,
)
try:
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
if not tc and m.find("**ERROR**:") >= 0:
raise Exception(m)
except Exception as e:

View File

@ -26,9 +26,10 @@ from api.db.db_models import APIToken
from api.db.services.api_service import API4ConversationService
from api.db.services.canvas_service import UserCanvasService, completion_openai
from api.db.services.canvas_service import completion as agent_completion
from api.db.services.conversation_service import ConversationService, iframe_completion
from api.db.services.conversation_service import completion as rag_completion
from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap, meta_filter
from api.db.services.conversation_service import ConversationService
from api.db.services.conversation_service import async_iframe_completion as iframe_completion
from api.db.services.conversation_service import async_completion as rag_completion
from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap, meta_filter
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
@ -141,7 +142,7 @@ async def chat_completion(tenant_id, chat_id):
return resp
else:
answer = None
for ans in rag_completion(tenant_id, chat_id, **req):
async for ans in rag_completion(tenant_id, chat_id, **req):
answer = ans
break
return get_result(data=answer)
@ -245,7 +246,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
# The value for the usage field on all chunks except for the last one will be null.
# The usage field on the last chunk contains token usage statistics for the entire request.
# The choices field on the last chunk will always be an empty array [].
def streamed_response_generator(chat_id, dia, msg):
async def streamed_response_generator(chat_id, dia, msg):
token_used = 0
answer_cache = ""
reasoning_cache = ""
@ -274,7 +275,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
}
try:
for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
async for ans in async_chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
last_ans = ans
answer = ans["answer"]
@ -342,7 +343,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
return resp
else:
answer = None
for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
async for ans in async_chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
# focus answer content only
answer = ans
break
@ -733,10 +734,10 @@ async def ask_about(tenant_id):
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
uid = tenant_id
def stream():
async def stream():
nonlocal req, uid
try:
for ans in ask(req["question"], req["kb_ids"], uid):
async for ans in async_ask(req["question"], req["kb_ids"], uid):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps(
@ -827,7 +828,7 @@ async def chatbot_completions(dialog_id):
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
for answer in iframe_completion(dialog_id, **req):
async for answer in iframe_completion(dialog_id, **req):
return get_result(data=answer)
@ -918,10 +919,10 @@ async def ask_about_embedded():
if search_app := SearchService.get_detail(search_id):
search_config = search_app.get("search_config", {})
def stream():
async def stream():
nonlocal req, uid
try:
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps(

View File

@ -19,7 +19,7 @@ from common.constants import StatusEnum
from api.db.db_models import Conversation, DB
from api.db.services.api_service import API4ConversationService
from api.db.services.common_service import CommonService
from api.db.services.dialog_service import DialogService, chat
from api.db.services.dialog_service import DialogService, async_chat
from common.misc_utils import get_uuid
import json
@ -89,8 +89,7 @@ def structure_answer(conv, ans, message_id, session_id):
conv.reference[-1] = reference
return ans
def completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
assert name, "`name` can not be empty."
dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
assert dia, "You do not own the chat."
@ -112,7 +111,7 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
"reference": {},
"audio_binary": None,
"id": None,
"session_id": session_id
"session_id": session_id
}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
@ -148,7 +147,7 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
if stream:
try:
for ans in chat(dia, msg, True, **kwargs):
async for ans in async_chat(dia, msg, True, **kwargs):
ans = structure_answer(conv, ans, message_id, session_id)
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
ConversationService.update_by_id(conv.id, conv.to_dict())
@ -160,14 +159,13 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
else:
answer = None
for ans in chat(dia, msg, False, **kwargs):
async for ans in async_chat(dia, msg, False, **kwargs):
answer = structure_answer(conv, ans, message_id, session_id)
ConversationService.update_by_id(conv.id, conv.to_dict())
break
yield answer
def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
async def async_iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
e, dia = DialogService.get_by_id(dialog_id)
assert e, "Dialog not found"
if not session_id:
@ -222,7 +220,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg
if stream:
try:
for ans in chat(dia, msg, True, **kwargs):
async for ans in async_chat(dia, msg, True, **kwargs):
ans = structure_answer(conv, ans, message_id, session_id)
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
ensure_ascii=False) + "\n\n"
@ -235,7 +233,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg
else:
answer = None
for ans in chat(dia, msg, False, **kwargs):
async for ans in async_chat(dia, msg, False, **kwargs):
answer = structure_answer(conv, ans, message_id, session_id)
API4ConversationService.append_message(conv.id, conv.to_dict())
break

View File

@ -178,7 +178,8 @@ class DialogService(CommonService):
offset += limit
return res
def chat_solo(dialog, messages, stream=True):
async def async_chat_solo(dialog, messages, stream=True):
attachments = ""
if "files" in messages[-1]:
attachments = "\n\n".join(FileService.get_files(messages[-1]["files"]))
@ -197,7 +198,8 @@ def chat_solo(dialog, messages, stream=True):
if stream:
last_ans = ""
delta_ans = ""
for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
answer = ""
async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
answer = ans
delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 16:
@ -208,7 +210,7 @@ def chat_solo(dialog, messages, stream=True):
if delta_ans:
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
else:
answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
user_content = msg[-1].get("content", "[content not available]")
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
@ -347,13 +349,12 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
return []
return list(doc_ids)
def chat(dialog, messages, stream=True, **kwargs):
async def async_chat(dialog, messages, stream=True, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
for ans in chat_solo(dialog, messages, stream):
async for ans in async_chat_solo(dialog, messages, stream):
yield ans
return None
return
chat_start_ts = timer()
@ -400,7 +401,7 @@ def chat(dialog, messages, stream=True, **kwargs):
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
if ans:
yield ans
return None
return
for p in prompt_config["parameters"]:
if p["key"] == "knowledge":
@ -508,7 +509,8 @@ def chat(dialog, messages, stream=True, **kwargs):
empty_res = prompt_config["empty_response"]
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
"audio_binary": tts(tts_mdl, empty_res)}
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
return
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
gen_conf = dialog.llm_setting
@ -612,7 +614,7 @@ def chat(dialog, messages, stream=True, **kwargs):
if stream:
last_ans = ""
answer = ""
for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
if thought:
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
answer = ans
@ -626,14 +628,14 @@ def chat(dialog, messages, stream=True, **kwargs):
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
yield decorate_answer(thought + answer)
else:
answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
user_content = msg[-1].get("content", "[content not available]")
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
res = decorate_answer(answer)
res["audio_binary"] = tts(tts_mdl, answer)
yield res
return None
return
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
@ -805,8 +807,7 @@ def tts(tts_mdl, text):
return None
return binascii.hexlify(bin).decode("utf-8")
def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
doc_ids = search_config.get("doc_ids", [])
rerank_mdl = None
kb_ids = search_config.get("kb_ids", kb_ids)
@ -880,7 +881,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
return {"answer": answer, "reference": refs}
answer = ""
for ans in chat_mdl.chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
async for ans in chat_mdl.async_chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
answer = ans
yield {"answer": answer, "reference": {}}
yield decorate_answer(answer)

View File

@ -25,14 +25,17 @@ Provides functionality for evaluating RAG system performance including:
- Configuration recommendations
"""
import asyncio
import logging
import queue
import threading
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from timeit import default_timer as timer
from api.db.db_models import EvaluationDataset, EvaluationCase, EvaluationRun, EvaluationResult
from api.db.services.common_service import CommonService
from api.db.services.dialog_service import DialogService, chat
from api.db.services.dialog_service import DialogService
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp
from common.constants import StatusEnum
@ -357,6 +360,42 @@ class EvaluationService(CommonService):
answer = ""
retrieved_chunks = []
def _sync_from_async_gen(async_gen):
result_queue: queue.Queue = queue.Queue()
def runner():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
async def consume():
try:
async for item in async_gen:
result_queue.put(item)
except Exception as e:
result_queue.put(e)
finally:
result_queue.put(StopIteration)
loop.run_until_complete(consume())
loop.close()
threading.Thread(target=runner, daemon=True).start()
while True:
item = result_queue.get()
if item is StopIteration:
break
if isinstance(item, Exception):
raise item
yield item
def chat(dialog, messages, stream=True, **kwargs):
from api.db.services.dialog_service import async_chat
return _sync_from_async_gen(async_chat(dialog, messages, stream=stream, **kwargs))
for ans in chat(dialog, messages, stream=False):
if isinstance(ans, dict):
answer = ans.get("answer", "")

View File

@ -16,15 +16,17 @@
import asyncio
import inspect
import logging
import queue
import re
import threading
from common.token_utils import num_tokens_from_string
from functools import partial
from typing import Generator
from common.constants import LLMType
from api.db.db_models import LLM
from api.db.services.common_service import CommonService
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
from common.constants import LLMType
from common.token_utils import num_tokens_from_string
class LLMService(CommonService):
@ -33,6 +35,7 @@ class LLMService(CommonService):
def get_init_tenant_llm(user_id):
from common import settings
tenant_llm = []
model_configs = {
@ -193,7 +196,7 @@ class LLMBundle(LLM4Tenant):
generation = self.langfuse.start_generation(
trace_context=self.trace_context,
name="stream_transcription",
metadata={"model": self.llm_name}
metadata={"model": self.llm_name},
)
final_text = ""
used_tokens = 0
@ -217,32 +220,34 @@ class LLMBundle(LLM4Tenant):
if self.langfuse:
generation.update(
output={"output": final_text},
usage_details={"total_tokens": used_tokens}
usage_details={"total_tokens": used_tokens},
)
generation.end()
return
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="stream_transcription", metadata={"model": self.llm_name})
full_text, used_tokens = mdl.transcription(audio)
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens
):
logging.error(
f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}"
generation = self.langfuse.start_generation(
trace_context=self.trace_context,
name="stream_transcription",
metadata={"model": self.llm_name},
)
full_text, used_tokens = mdl.transcription(audio)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
logging.error(f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}")
if self.langfuse:
generation.update(
output={"output": full_text},
usage_details={"total_tokens": used_tokens}
usage_details={"total_tokens": used_tokens},
)
generation.end()
yield {
"event": "final",
"text": full_text,
"streaming": False
"streaming": False,
}
def tts(self, text: str) -> Generator[bytes, None, None]:
@ -289,61 +294,79 @@ class LLMBundle(LLM4Tenant):
return kwargs
else:
return {k: v for k, v in kwargs.items() if k in allowed_params}
def _run_coroutine_sync(self, coro):
try:
asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(coro)
result_queue: queue.Queue = queue.Queue()
def runner():
try:
result_queue.put((True, asyncio.run(coro)))
except Exception as e:
result_queue.put((False, e))
thread = threading.Thread(target=runner, daemon=True)
thread.start()
thread.join()
success, value = result_queue.get_nowait()
if success:
return value
raise value
def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
return self._run_coroutine_sync(self.async_chat(system, history, gen_conf, **kwargs))
chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
if self.is_tools and self.mdl.is_tools:
chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
def _sync_from_async_stream(self, async_gen_fn, *args, **kwargs):
result_queue: queue.Queue = queue.Queue()
use_kwargs = self._clean_param(chat_partial, **kwargs)
txt, used_tokens = chat_partial(**use_kwargs)
txt = self._remove_reasoning_content(txt)
def runner():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
async def consume():
try:
async for item in async_gen_fn(*args, **kwargs):
result_queue.put(item)
except Exception as e:
result_queue.put(e)
finally:
result_queue.put(StopIteration)
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
loop.run_until_complete(consume())
loop.close()
if self.langfuse:
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
generation.end()
threading.Thread(target=runner, daemon=True).start()
return txt
while True:
item = result_queue.get()
if item is StopIteration:
break
if isinstance(item, Exception):
raise item
yield item
def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
ans = ""
chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf)
total_tokens = 0
if self.is_tools and self.mdl.is_tools:
chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
for txt in chat_partial(**use_kwargs):
for txt in self._sync_from_async_stream(self.async_chat_streamly, system, history, gen_conf, **kwargs):
if isinstance(txt, int):
total_tokens = txt
if self.langfuse:
generation.update(output={"output": ans})
generation.end()
break
if txt.endswith("</think>"):
ans = ans[: -len("</think>")]
ans = txt[: -len("</think>")]
continue
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
ans += txt
# cancatination has beend done in async_chat_streamly
ans = txt
yield ans
if total_tokens > 0:
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
def _bridge_sync_stream(self, gen):
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
@ -352,7 +375,7 @@ class LLMBundle(LLM4Tenant):
try:
for item in gen:
loop.call_soon_threadsafe(queue.put_nowait, item)
except Exception as e: # pragma: no cover
except Exception as e:
loop.call_soon_threadsafe(queue.put_nowait, e)
finally:
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
@ -361,18 +384,27 @@ class LLMBundle(LLM4Tenant):
return queue
async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
if self.is_tools and self.mdl.is_tools and hasattr(self.mdl, "chat_with_tools"):
chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_with_tools"):
base_fn = self.mdl.async_chat_with_tools
elif hasattr(self.mdl, "async_chat"):
base_fn = self.mdl.async_chat
else:
raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
generation = None
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
chat_partial = partial(base_fn, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
if hasattr(self.mdl, "async_chat_with_tools") and self.is_tools and self.mdl.is_tools:
txt, used_tokens = await self.mdl.async_chat_with_tools(system, history, gen_conf, **use_kwargs)
elif hasattr(self.mdl, "async_chat"):
txt, used_tokens = await self.mdl.async_chat(system, history, gen_conf, **use_kwargs)
else:
txt, used_tokens = await asyncio.to_thread(chat_partial, **use_kwargs)
try:
txt, used_tokens = await chat_partial(**use_kwargs)
except Exception as e:
if generation:
generation.update(output={"error": str(e)})
generation.end()
raise
txt = self._remove_reasoning_content(txt)
if not self.verbose_tool_use:
@ -381,49 +413,51 @@ class LLMBundle(LLM4Tenant):
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
if generation:
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
generation.end()
return txt
async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
total_tokens = 0
ans = ""
if self.is_tools and self.mdl.is_tools:
if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"):
stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
else:
elif hasattr(self.mdl, "async_chat_streamly"):
stream_fn = getattr(self.mdl, "async_chat_streamly", None)
else:
raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
generation = None
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
if stream_fn:
chat_partial = partial(stream_fn, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
async for txt in chat_partial(**use_kwargs):
if isinstance(txt, int):
total_tokens = txt
break
try:
async for txt in chat_partial(**use_kwargs):
if isinstance(txt, int):
total_tokens = txt
break
if txt.endswith("</think>"):
ans = ans[: -len("</think>")]
if txt.endswith("</think>"):
ans = ans[: -len("</think>")]
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
ans += txt
yield ans
ans += txt
yield ans
except Exception as e:
if generation:
generation.update(output={"error": str(e)})
generation.end()
raise
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
if generation:
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
generation.end()
return
chat_partial = partial(self.mdl.chat_streamly_with_tools if (self.is_tools and self.mdl.is_tools) else self.mdl.chat_streamly, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
queue = self._bridge_sync_stream(chat_partial(**use_kwargs))
while True:
item = await queue.get()
if item is StopAsyncIteration:
break
if isinstance(item, Exception):
raise item
if isinstance(item, int):
total_tokens = item
break
yield item
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))

View File

@ -52,6 +52,8 @@ class SupportedLiteLLMProvider(StrEnum):
JiekouAI = "Jiekou.AI"
ZHIPU_AI = "ZHIPU-AI"
MiniMax = "MiniMax"
DeerAPI = "DeerAPI"
GPUStack = "GPUStack"
FACTORY_DEFAULT_BASE_URL = {
@ -75,6 +77,7 @@ FACTORY_DEFAULT_BASE_URL = {
SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai",
SupportedLiteLLMProvider.ZHIPU_AI: "https://open.bigmodel.cn/api/paas/v4",
SupportedLiteLLMProvider.MiniMax: "https://api.minimaxi.com/v1",
SupportedLiteLLMProvider.DeerAPI: "https://api.deerapi.com/v1",
}
@ -108,6 +111,8 @@ LITELLM_PROVIDER_PREFIX = {
SupportedLiteLLMProvider.JiekouAI: "openai/",
SupportedLiteLLMProvider.ZHIPU_AI: "openai/",
SupportedLiteLLMProvider.MiniMax: "openai/",
SupportedLiteLLMProvider.DeerAPI: "openai/",
SupportedLiteLLMProvider.GPUStack: "openai/",
}
ChatModel = globals().get("ChatModel", {})

View File

@ -19,7 +19,6 @@ import logging
import os
import random
import re
import threading
import time
from abc import ABC
from copy import deepcopy
@ -78,11 +77,9 @@ class Base(ABC):
self.toolcall_sessions = {}
def _get_delay(self):
"""Calculate retry delay time"""
return self.base_delay * random.uniform(10, 150)
def _classify_error(self, error):
"""Classify error based on error message content"""
error_str = str(error).lower()
keywords_mapping = [
@ -139,89 +136,7 @@ class Base(ABC):
return gen_conf
def _bridge_sync_stream(self, gen):
"""Run a sync generator in a thread and yield asynchronously."""
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
def worker():
try:
for item in gen:
loop.call_soon_threadsafe(queue.put_nowait, item)
except Exception as exc: # pragma: no cover - defensive
loop.call_soon_threadsafe(queue.put_nowait, exc)
finally:
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
threading.Thread(target=worker, daemon=True).start()
return queue
def _chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwq") >= 0:
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using _chat_streamly")
final_ans = ""
tol_token = 0
for delta, tol in self._chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
if delta.startswith("<think>") or delta.endswith("</think>"):
continue
final_ans += delta
tol_token = tol
if len(final_ans.strip()) == 0:
final_ans = "**ERROR**: Empty response from reasoning model"
return final_ans.strip(), tol_token
if self.model_name.lower().find("qwen3") >= 0:
kwargs["extra_body"] = {"enable_thinking": False}
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
return "", 0
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, total_token_count_from_response(response)
def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
reasoning_start = False
if kwargs.get("stop") or "stop" in gen_conf:
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
else:
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>"
else:
reasoning_start = False
ans = resp.choices[0].delta.content
tol = total_token_count_from_response(resp)
if not tol:
tol = num_tokens_from_string(resp.choices[0].delta.content)
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans, tol
async def _async_chat_stream(self, history, gen_conf, **kwargs):
async def _async_chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
reasoning_start = False
@ -265,13 +180,19 @@ class Base(ABC):
gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
try:
async for delta_ans, tol in self._async_chat_stream(history, gen_conf, **kwargs):
ans = delta_ans
total_tokens += tol
yield delta_ans
except openai.APIError as e:
yield ans + "\n**ERROR**: " + str(e)
for attempt in range(self.max_retries + 1):
try:
async for delta_ans, tol in self._async_chat_streamly(history, gen_conf, **kwargs):
ans = delta_ans
total_tokens += tol
yield ans
except Exception as e:
e = await self._exceptions_async(e, attempt)
if e:
yield e
yield total_tokens
return
yield total_tokens
@ -307,7 +228,7 @@ class Base(ABC):
logging.error(f"sync base giving up: {msg}")
return msg
async def _exceptions_async(self, e, attempt) -> str | None:
async def _exceptions_async(self, e, attempt):
logging.exception("OpenAI async completion")
error_code = self._classify_error(e)
if attempt == self.max_retries:
@ -357,61 +278,6 @@ class Base(ABC):
self.toolcall_session = toolcall_session
self.tools = tools
def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
ans = ""
tk_count = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = hist
try:
for _ in range(self.max_rounds + 1):
logging.info(f"{self.tools=}")
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
tk_count += total_token_count_from_response(response)
if any([not response.choices, not response.choices[0].message]):
raise Exception(f"500 response structure error. Response: {response}")
if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"
ans += response.choices[0].message.content
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, tk_count
for tool_call in response.choices[0].message.tool_calls:
logging.info(f"Response {tool_call=}")
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_session.tool_call(name, args)
history = self._append_history(history, tool_call, tool_response)
ans += self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
ans += self._verbose_tool_use(name, {}, str(e))
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response, token_count = self._chat(history, gen_conf)
ans += response
tk_count += token_count
return ans, tk_count
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, tk_count
assert False, "Shouldn't be here."
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system":
@ -466,140 +332,6 @@ class Base(ABC):
assert False, "Shouldn't be here."
def chat(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
try:
return self._chat(history, gen_conf, **kwargs)
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, 0
assert False, "Shouldn't be here."
def _wrap_toolcall_message(self, stream):
final_tool_calls = {}
for chunk in stream:
for tool_call in chunk.choices[0].delta.tool_calls or []:
index = tool_call.index
if index not in final_tool_calls:
final_tool_calls[index] = tool_call
final_tool_calls[index].function.arguments += tool_call.function.arguments
return final_tool_calls
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
total_tokens = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = hist
try:
for _ in range(self.max_rounds + 1):
reasoning_start = False
logging.info(f"{tools=}")
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
final_tool_calls = {}
answer = ""
for resp in response:
if resp.choices[0].delta.tool_calls:
for tool_call in resp.choices[0].delta.tool_calls or []:
index = tool_call.index
if index not in final_tool_calls:
if not tool_call.function.arguments:
tool_call.function.arguments = ""
final_tool_calls[index] = tool_call
else:
final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""
continue
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
raise Exception("500 response structure error.")
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>"
yield ans
else:
reasoning_start = False
answer += resp.choices[0].delta.content
yield resp.choices[0].delta.content
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
total_tokens = tol
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
if finish_reason == "length":
yield self._length_stop("")
if answer:
yield total_tokens
return
for tool_call in final_tool_calls.values():
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
yield self._verbose_tool_use(name, args, "Begin to call...")
tool_response = self.toolcall_session.tool_call(name, args)
history = self._append_history(history, tool_call, tool_response)
yield self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
yield self._verbose_tool_use(name, {}, str(e))
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
raise Exception("500 response structure error.")
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
continue
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
total_tokens = tol
answer += resp.choices[0].delta.content
yield resp.choices[0].delta.content
yield total_tokens
return
except Exception as e:
e = self._exceptions(e, attempt)
if e:
yield e
yield total_tokens
return
assert False, "Shouldn't be here."
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
@ -715,9 +447,10 @@ class Base(ABC):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwq") >= 0:
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly")
final_ans = ""
tol_token = 0
async for delta, tol in self._async_chat_stream(history, gen_conf, with_reasoning=False, **kwargs):
async for delta, tol in self._async_chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
if delta.startswith("<think>") or delta.endswith("</think>"):
continue
final_ans += delta
@ -754,57 +487,6 @@ class Base(ABC):
return e, 0
assert False, "Shouldn't be here."
def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
try:
for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
yield delta_ans
total_tokens += tol
except openai.APIError as e:
yield ans + "\n**ERROR**: " + str(e)
yield total_tokens
def _calculate_dynamic_ctx(self, history):
"""Calculate dynamic context window size"""
def count_tokens(text):
"""Calculate token count for text"""
# Simple calculation: 1 token per ASCII character
# 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
total = 0
for char in text:
if ord(char) < 128: # ASCII characters
total += 1
else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
total += 2
return total
# Calculate total tokens for all messages
total_tokens = 0
for message in history:
content = message.get("content", "")
# Calculate content tokens
content_tokens = count_tokens(content)
# Add role marker token overhead
role_tokens = 4
total_tokens += content_tokens + role_tokens
# Apply 1.2x buffer ratio
total_tokens_with_buffer = int(total_tokens * 1.2)
if total_tokens_with_buffer <= 8192:
ctx_size = 8192
else:
ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
ctx_size = ctx_multiplier * 8192
return ctx_size
class GptTurbo(Base):
_FACTORY_NAME = "OpenAI"
@ -1504,16 +1186,6 @@ class GoogleChat(Base):
yield total_tokens
class GPUStackChat(Base):
_FACTORY_NAME = "GPUStack"
def __init__(self, key=None, model_name="", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
super().__init__(key, model_name, base_url, **kwargs)
class TokenPonyChat(Base):
_FACTORY_NAME = "TokenPony"
@ -1523,15 +1195,6 @@ class TokenPonyChat(Base):
super().__init__(key, model_name, base_url, **kwargs)
class DeerAPIChat(Base):
_FACTORY_NAME = "DeerAPI"
def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1", **kwargs):
if not base_url:
base_url = "https://api.deerapi.com/v1"
super().__init__(key, model_name, base_url, **kwargs)
class LiteLLMBase(ABC):
_FACTORY_NAME = [
"Tongyi-Qianwen",
@ -1562,6 +1225,8 @@ class LiteLLMBase(ABC):
"Jiekou.AI",
"ZHIPU-AI",
"MiniMax",
"DeerAPI",
"GPUStack",
]
def __init__(self, key, model_name, base_url=None, **kwargs):
@ -1589,11 +1254,9 @@ class LiteLLMBase(ABC):
self.provider_order = json.loads(key).get("provider_order", "")
def _get_delay(self):
"""Calculate retry delay time"""
return self.base_delay * random.uniform(10, 150)
def _classify_error(self, error):
"""Classify error based on error message content"""
error_str = str(error).lower()
keywords_mapping = [
@ -1619,72 +1282,6 @@ class LiteLLMBase(ABC):
del gen_conf["max_tokens"]
return gen_conf
def _chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwen3") >= 0:
kwargs["extra_body"] = {"enable_thinking": False}
completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf)
response = litellm.completion(
**completion_args,
drop_params=True,
timeout=self.timeout,
)
# response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
return "", 0
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, total_token_count_from_response(response)
def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
gen_conf = self._clean_conf(gen_conf)
reasoning_start = False
completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
stop = kwargs.get("stop")
if stop:
completion_args["stop"] = stop
response = litellm.completion(
**completion_args,
drop_params=True,
timeout=self.timeout,
)
for resp in response:
if not hasattr(resp, "choices") or not resp.choices:
continue
delta = resp.choices[0].delta
if not hasattr(delta, "content") or delta.content is None:
delta.content = ""
if kwargs.get("with_reasoning", True) and hasattr(delta, "reasoning_content") and delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += delta.reasoning_content + "</think>"
else:
reasoning_start = False
ans = delta.content
tol = total_token_count_from_response(resp)
if not tol:
tol = num_tokens_from_string(delta.content)
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
if finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans, tol
async def async_chat(self, system, history, gen_conf, **kwargs):
hist = list(history) if history else []
if system:
@ -1795,22 +1392,7 @@ class LiteLLMBase(ABC):
def _should_retry(self, error_code: str) -> bool:
return error_code in self._retryable_errors
def _exceptions(self, e, attempt) -> str | None:
logging.exception("OpenAI chat_with_tools")
# Classify the error
error_code = self._classify_error(e)
if attempt == self.max_retries:
error_code = LLMErrorCode.ERROR_MAX_RETRIES
if self._should_retry(error_code):
delay = self._get_delay()
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
time.sleep(delay)
return None
return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
async def _exceptions_async(self, e, attempt) -> str | None:
async def _exceptions_async(self, e, attempt):
logging.exception("LiteLLMBase async completion")
error_code = self._classify_error(e)
if attempt == self.max_retries:
@ -1859,71 +1441,7 @@ class LiteLLMBase(ABC):
self.toolcall_session = toolcall_session
self.tools = tools
def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs):
completion_args = {
"model": self.model_name,
"messages": history,
"api_key": self.api_key,
"num_retries": self.max_retries,
**kwargs,
}
if stream:
completion_args.update(
{
"stream": stream,
}
)
if tools and self.tools:
completion_args.update(
{
"tools": self.tools,
"tool_choice": "auto",
}
)
if self.provider in FACTORY_DEFAULT_BASE_URL:
completion_args.update({"api_base": self.base_url})
elif self.provider == SupportedLiteLLMProvider.Bedrock:
completion_args.pop("api_key", None)
completion_args.pop("api_base", None)
completion_args.update(
{
"aws_access_key_id": self.bedrock_ak,
"aws_secret_access_key": self.bedrock_sk,
"aws_region_name": self.bedrock_region,
}
)
if self.provider == SupportedLiteLLMProvider.OpenRouter:
if self.provider_order:
def _to_order_list(x):
if x is None:
return []
if isinstance(x, str):
return [s.strip() for s in x.split(",") if s.strip()]
if isinstance(x, (list, tuple)):
return [str(s).strip() for s in x if str(s).strip()]
return []
extra_body = {}
provider_cfg = {}
provider_order = _to_order_list(self.provider_order)
provider_cfg["order"] = provider_order
provider_cfg["allow_fallbacks"] = False
extra_body["provider"] = provider_cfg
completion_args.update({"extra_body": extra_body})
# Ollama deployments commonly sit behind a reverse proxy that enforces
# Bearer auth. Ensure the Authorization header is set when an API key
# is provided, while respecting any user-supplied headers. #11350
extra_headers = deepcopy(completion_args.get("extra_headers") or {})
if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers:
extra_headers["Authorization"] = f"Bearer {self.api_key}"
if extra_headers:
completion_args["extra_headers"] = extra_headers
return completion_args
def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
@ -1931,16 +1449,14 @@ class LiteLLMBase(ABC):
ans = ""
tk_count = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = deepcopy(hist) # deepcopy is required here
history = deepcopy(hist)
try:
for _ in range(self.max_rounds + 1):
logging.info(f"{self.tools=}")
completion_args = self._construct_completion_args(history=history, stream=False, tools=True, **gen_conf)
response = litellm.completion(
response = await litellm.acompletion(
**completion_args,
drop_params=True,
timeout=self.timeout,
@ -1966,7 +1482,7 @@ class LiteLLMBase(ABC):
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_session.tool_call(name, args)
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
history = self._append_history(history, tool_call, tool_response)
ans += self._verbose_tool_use(name, args, tool_response)
except Exception as e:
@ -1977,49 +1493,19 @@ class LiteLLMBase(ABC):
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response, token_count = self._chat(history, gen_conf)
response, token_count = await self.async_chat("", history, gen_conf)
ans += response
tk_count += token_count
return ans, tk_count
except Exception as e:
e = self._exceptions(e, attempt)
e = await self._exceptions_async(e, attempt)
if e:
return e, tk_count
assert False, "Shouldn't be here."
def chat(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
try:
response = self._chat(history, gen_conf, **kwargs)
return response
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, 0
assert False, "Shouldn't be here."
def _wrap_toolcall_message(self, stream):
final_tool_calls = {}
for chunk in stream:
for tool_call in chunk.choices[0].delta.tool_calls or []:
index = tool_call.index
if index not in final_tool_calls:
final_tool_calls[index] = tool_call
final_tool_calls[index].function.arguments += tool_call.function.arguments
return final_tool_calls
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
if system and history and history[0].get("role") != "system":
@ -2028,16 +1514,15 @@ class LiteLLMBase(ABC):
total_tokens = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = deepcopy(hist) # deepcopy is required here
history = deepcopy(hist)
try:
for _ in range(self.max_rounds + 1):
reasoning_start = False
logging.info(f"{tools=}")
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
response = litellm.completion(
response = await litellm.acompletion(
**completion_args,
drop_params=True,
timeout=self.timeout,
@ -2046,7 +1531,7 @@ class LiteLLMBase(ABC):
final_tool_calls = {}
answer = ""
for resp in response:
async for resp in response:
if not hasattr(resp, "choices") or not resp.choices:
continue
@ -2082,7 +1567,7 @@ class LiteLLMBase(ABC):
if not tol:
total_tokens += num_tokens_from_string(delta.content)
else:
total_tokens += tol
total_tokens = tol
finish_reason = getattr(resp.choices[0], "finish_reason", "")
if finish_reason == "length":
@ -2097,31 +1582,25 @@ class LiteLLMBase(ABC):
try:
args = json_repair.loads(tool_call.function.arguments)
yield self._verbose_tool_use(name, args, "Begin to call...")
tool_response = self.toolcall_session.tool_call(name, args)
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
history = self._append_history(history, tool_call, tool_response)
yield self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": f"Tool call error: \n{tool_call}\nException:\n{str(e)}",
}
)
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
yield self._verbose_tool_use(name, {}, str(e))
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
response = litellm.completion(
response = await litellm.acompletion(
**completion_args,
drop_params=True,
timeout=self.timeout,
)
for resp in response:
async for resp in response:
if not hasattr(resp, "choices") or not resp.choices:
continue
delta = resp.choices[0].delta
@ -2131,14 +1610,14 @@ class LiteLLMBase(ABC):
if not tol:
total_tokens += num_tokens_from_string(delta.content)
else:
total_tokens += tol
total_tokens = tol
yield delta.content
yield total_tokens
return
except Exception as e:
e = self._exceptions(e, attempt)
e = await self._exceptions_async(e, attempt)
if e:
yield e
yield total_tokens
@ -2146,53 +1625,71 @@ class LiteLLMBase(ABC):
assert False, "Shouldn't be here."
def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
try:
for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
yield delta_ans
total_tokens += tol
except openai.APIError as e:
yield ans + "\n**ERROR**: " + str(e)
def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs):
completion_args = {
"model": self.model_name,
"messages": history,
"api_key": self.api_key,
"num_retries": self.max_retries,
**kwargs,
}
if stream:
completion_args.update(
{
"stream": stream,
}
)
if tools and self.tools:
completion_args.update(
{
"tools": self.tools,
"tool_choice": "auto",
}
)
if self.provider in FACTORY_DEFAULT_BASE_URL:
completion_args.update({"api_base": self.base_url})
elif self.provider == SupportedLiteLLMProvider.Bedrock:
completion_args.pop("api_key", None)
completion_args.pop("api_base", None)
completion_args.update(
{
"aws_access_key_id": self.bedrock_ak,
"aws_secret_access_key": self.bedrock_sk,
"aws_region_name": self.bedrock_region,
}
)
elif self.provider == SupportedLiteLLMProvider.OpenRouter:
if self.provider_order:
yield total_tokens
def _to_order_list(x):
if x is None:
return []
if isinstance(x, str):
return [s.strip() for s in x.split(",") if s.strip()]
if isinstance(x, (list, tuple)):
return [str(s).strip() for s in x if str(s).strip()]
return []
def _calculate_dynamic_ctx(self, history):
"""Calculate dynamic context window size"""
extra_body = {}
provider_cfg = {}
provider_order = _to_order_list(self.provider_order)
provider_cfg["order"] = provider_order
provider_cfg["allow_fallbacks"] = False
extra_body["provider"] = provider_cfg
completion_args.update({"extra_body": extra_body})
elif self.provider == SupportedLiteLLMProvider.GPUStack:
completion_args.update(
{
"api_base": self.base_url,
}
)
def count_tokens(text):
"""Calculate token count for text"""
# Simple calculation: 1 token per ASCII character
# 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
total = 0
for char in text:
if ord(char) < 128: # ASCII characters
total += 1
else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
total += 2
return total
# Calculate total tokens for all messages
total_tokens = 0
for message in history:
content = message.get("content", "")
# Calculate content tokens
content_tokens = count_tokens(content)
# Add role marker token overhead
role_tokens = 4
total_tokens += content_tokens + role_tokens
# Apply 1.2x buffer ratio
total_tokens_with_buffer = int(total_tokens * 1.2)
if total_tokens_with_buffer <= 8192:
ctx_size = 8192
else:
ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
ctx_size = ctx_multiplier * 8192
return ctx_size
# Ollama deployments commonly sit behind a reverse proxy that enforces
# Bearer auth. Ensure the Authorization header is set when an API key
# is provided, while respecting any user-supplied headers. #11350
extra_headers = deepcopy(completion_args.get("extra_headers") or {})
if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers:
extra_headers["Authorization"] = f"Bearer {self.api_key}"
if extra_headers:
completion_args["extra_headers"] = extra_headers
return completion_args