Mirror of https://github.com/infiniflow/ragflow.git, synced 2025-12-08 20:42:30 +08:00
Refa: make RAGFlow more asynchronous (#11601)
### What problem does this PR solve?

Try to make this more asynchronous. Verified in chat and agent scenarios, reducing blocking behavior. #11551, #11579. However, the impact of these changes still requires further investigation to ensure everything works as expected.

### Type of change

- [x] Refactoring
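For reviewers, a minimal consumption sketch of the new async entry points added in this diff. The chat_model instance, prompt, and generation config are placeholders; per the code below, async_chat returns (answer, token_count) and async_chat_streamly yields text chunks followed by the total token count as its final item.

import asyncio

async def demo(chat_model):
    # Non-streaming: returns the answer text and a token count.
    ans, tokens = await chat_model.async_chat(
        "You are a helpful assistant.",
        [{"role": "user", "content": "Hello"}],
        {"temperature": 0.7},
    )
    print(ans, tokens)

    # Streaming: text deltas first, then the accumulated token count (an int).
    async for chunk in chat_model.async_chat_streamly(
        "You are a helpful assistant.",
        [{"role": "user", "content": "Hello"}],
        {"temperature": 0.7},
    ):
        if isinstance(chunk, int):
            print("\ntokens:", chunk)
        else:
            print(chunk, end="")

# asyncio.run(demo(chat_model)) with a configured Base-derived model instance.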
@@ -19,6 +19,7 @@ import logging
import os
import random
import re
import threading
import time
from abc import ABC
from copy import deepcopy
@@ -28,10 +29,9 @@ import json_repair
import litellm
import openai
import requests
from openai import OpenAI
from openai import AsyncOpenAI, OpenAI
from openai.lib.azure import AzureOpenAI
from strenum import StrEnum
from zhipuai import ZhipuAI

from common.token_utils import num_tokens_from_string, total_token_count_from_response
from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
@@ -68,6 +68,7 @@ class Base(ABC):
    def __init__(self, key, model_name, base_url, **kwargs):
        timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600))
        self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
        self.async_client = AsyncOpenAI(api_key=key, base_url=base_url, timeout=timeout)
        self.model_name = model_name
        # Configure retry parameters
        self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
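Both the sync and async clients share the same timeout and retry budget, read from environment variables at construction time. A minimal sketch of overriding them before any model client is created; the values are purely illustrative.

import os

# Illustrative overrides; the constructor above defaults to 600 s and 5 retries.
os.environ.setdefault("LLM_TIMEOUT_SECONDS", "300")
os.environ.setdefault("LLM_MAX_RETRIES", "3")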
@@ -139,6 +140,23 @@ class Base(ABC):

        return gen_conf

    def _bridge_sync_stream(self, gen):
        """Run a sync generator in a thread and yield asynchronously."""
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()

        def worker():
            try:
                for item in gen:
                    loop.call_soon_threadsafe(queue.put_nowait, item)
            except Exception as exc:  # pragma: no cover - defensive
                loop.call_soon_threadsafe(queue.put_nowait, exc)
            finally:
                loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)

        threading.Thread(target=worker, daemon=True).start()
        return queue
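The helper above only pushes items into an asyncio.Queue from a worker thread; an event-loop-side consumer still has to drain it. A minimal sketch of such a consumer, using the hypothetical wrapper name _aiter_bridge, which is not part of this diff.

    async def _aiter_bridge(self, gen):
        # Drain the queue filled by _bridge_sync_stream until the
        # StopAsyncIteration sentinel arrives; re-raise forwarded exceptions.
        queue = self._bridge_sync_stream(gen)
        while True:
            item = await queue.get()
            if item is StopAsyncIteration:
                break
            if isinstance(item, Exception):
                raise item
            yield item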

    def _chat(self, history, gen_conf, **kwargs):
        logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
        if self.model_name.lower().find("qwq") >= 0:
@@ -204,6 +222,60 @@ class Base(ABC):
                    ans += LENGTH_NOTIFICATION_EN
            yield ans, tol

    async def _async_chat_stream(self, history, gen_conf, **kwargs):
        logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
        reasoning_start = False

        request_kwargs = {"model": self.model_name, "messages": history, "stream": True, **gen_conf}
        stop = kwargs.get("stop")
        if stop:
            request_kwargs["stop"] = stop

        response = await self.async_client.chat.completions.create(**request_kwargs)

        async for resp in response:
            if not resp.choices:
                continue
            if not resp.choices[0].delta.content:
                resp.choices[0].delta.content = ""
            if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
                ans = ""
                if not reasoning_start:
                    reasoning_start = True
                    ans = "<think>"
                ans += resp.choices[0].delta.reasoning_content + "</think>"
            else:
                reasoning_start = False
                ans = resp.choices[0].delta.content

            tol = total_token_count_from_response(resp)
            if not tol:
                tol = num_tokens_from_string(resp.choices[0].delta.content)

            finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
            if finish_reason == "length":
                if is_chinese(ans):
                    ans += LENGTH_NOTIFICATION_CN
                else:
                    ans += LENGTH_NOTIFICATION_EN
            yield ans, tol

    async def async_chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
        if system and history and history[0].get("role") != "system":
            history.insert(0, {"role": "system", "content": system})
        gen_conf = self._clean_conf(gen_conf)
        ans = ""
        total_tokens = 0
        try:
            async for delta_ans, tol in self._async_chat_stream(history, gen_conf, **kwargs):
                ans = delta_ans
                total_tokens += tol
                yield delta_ans
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens

    def _length_stop(self, ans):
        if is_chinese([ans]):
            return ans + LENGTH_NOTIFICATION_CN
@@ -232,7 +304,25 @@ class Base(ABC):
            time.sleep(delay)
            return None

        return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
        msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}"
        logging.error(f"sync base giving up: {msg}")
        return msg

    async def _exceptions_async(self, e, attempt) -> str | None:
        logging.exception("OpenAI async completion")
        error_code = self._classify_error(e)
        if attempt == self.max_retries:
            error_code = LLMErrorCode.ERROR_MAX_RETRIES

        if self._should_retry(error_code):
            delay = self._get_delay()
            logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
            await asyncio.sleep(delay)
            return None

        msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}"
        logging.error(f"async base giving up: {msg}")
        return msg

    def _verbose_tool_use(self, name, args, res):
        return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"
@@ -323,6 +413,60 @@ class Base(ABC):

        assert False, "Shouldn't be here."

    async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
        gen_conf = self._clean_conf(gen_conf)
        if system and history and history[0].get("role") != "system":
            history.insert(0, {"role": "system", "content": system})

        ans = ""
        tk_count = 0
        hist = deepcopy(history)
        for attempt in range(self.max_retries + 1):
            history = deepcopy(hist)
            try:
                for _ in range(self.max_rounds + 1):
                    logging.info(f"{self.tools=}")
                    response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
                    tk_count += total_token_count_from_response(response)
                    if any([not response.choices, not response.choices[0].message]):
                        raise Exception(f"500 response structure error. Response: {response}")

                    if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
                        if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
                            ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"

                        ans += response.choices[0].message.content
                        if response.choices[0].finish_reason == "length":
                            ans = self._length_stop(ans)

                        return ans, tk_count

                    for tool_call in response.choices[0].message.tool_calls:
                        logging.info(f"Response {tool_call=}")
                        name = tool_call.function.name
                        try:
                            args = json_repair.loads(tool_call.function.arguments)
                            tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
                            history = self._append_history(history, tool_call, tool_response)
                            ans += self._verbose_tool_use(name, args, tool_response)
                        except Exception as e:
                            logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
                            ans += self._verbose_tool_use(name, {}, str(e))

                logging.warning(f"Exceed max rounds: {self.max_rounds}")
                history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
                response, token_count = await self._async_chat(history, gen_conf)
                ans += response
                tk_count += token_count
                return ans, tk_count
            except Exception as e:
                e = await self._exceptions_async(e, attempt)
                if e:
                    return e, tk_count

        assert False, "Shouldn't be here."

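async_chat_with_tools assumes an injected toolcall_session whose synchronous tool_call(name, args) method is offloaded with asyncio.to_thread, plus a self.tools list in the OpenAI function-calling schema. A hypothetical minimal session, only to illustrate the assumed interface; RAGFlow wires in its own implementation.

class EchoToolSession:
    """Hypothetical stand-in for the session object bound to the model."""

    def tool_call(self, name: str, args: dict) -> str:
        # Runs synchronously; Base offloads this call to a worker thread.
        if name == "echo":
            return str(args.get("text", ""))
        return f"unknown tool: {name}"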
    def chat(self, system, history, gen_conf={}, **kwargs):
        if system and history and history[0].get("role") != "system":
            history.insert(0, {"role": "system", "content": system})
@@ -457,6 +601,160 @@ class Base(ABC):

        assert False, "Shouldn't be here."

    async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
        gen_conf = self._clean_conf(gen_conf)
        tools = self.tools
        if system and history and history[0].get("role") != "system":
            history.insert(0, {"role": "system", "content": system})

        total_tokens = 0
        hist = deepcopy(history)

        for attempt in range(self.max_retries + 1):
            history = deepcopy(hist)
            try:
                for _ in range(self.max_rounds + 1):
                    reasoning_start = False
                    logging.info(f"{tools=}")

                    response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)

                    final_tool_calls = {}
                    answer = ""

                    async for resp in response:
                        if not hasattr(resp, "choices") or not resp.choices:
                            continue

                        delta = resp.choices[0].delta

                        if hasattr(delta, "tool_calls") and delta.tool_calls:
                            for tool_call in delta.tool_calls:
                                index = tool_call.index
                                if index not in final_tool_calls:
                                    if not tool_call.function.arguments:
                                        tool_call.function.arguments = ""
                                    final_tool_calls[index] = tool_call
                                else:
                                    final_tool_calls[index].function.arguments += tool_call.function.arguments or ""
                            continue

                        if not hasattr(delta, "content") or delta.content is None:
                            delta.content = ""

                        if hasattr(delta, "reasoning_content") and delta.reasoning_content:
                            ans = ""
                            if not reasoning_start:
                                reasoning_start = True
                                ans = "<think>"
                            ans += delta.reasoning_content + "</think>"
                            yield ans
                        else:
                            reasoning_start = False
                            answer += delta.content
                            yield delta.content

                        tol = total_token_count_from_response(resp)
                        if not tol:
                            total_tokens += num_tokens_from_string(delta.content)
                        else:
                            total_tokens = tol

                        finish_reason = getattr(resp.choices[0], "finish_reason", "")
                        if finish_reason == "length":
                            yield self._length_stop("")

                    if answer:
                        yield total_tokens
                        return

                    for tool_call in final_tool_calls.values():
                        name = tool_call.function.name
                        try:
                            args = json_repair.loads(tool_call.function.arguments)
                            yield self._verbose_tool_use(name, args, "Begin to call...")
                            tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
                            history = self._append_history(history, tool_call, tool_response)
                            yield self._verbose_tool_use(name, args, tool_response)
                        except Exception as e:
                            logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
                            yield self._verbose_tool_use(name, {}, str(e))

                logging.warning(f"Exceed max rounds: {self.max_rounds}")
                history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})

                response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)

                async for resp in response:
                    if not hasattr(resp, "choices") or not resp.choices:
                        continue
                    delta = resp.choices[0].delta
                    if not hasattr(delta, "content") or delta.content is None:
                        continue
                    tol = total_token_count_from_response(resp)
                    if not tol:
                        total_tokens += num_tokens_from_string(delta.content)
                    else:
                        total_tokens = tol
                    yield delta.content

                yield total_tokens
                return

            except Exception as e:
                e = await self._exceptions_async(e, attempt)
                if e:
                    logging.error(f"async_chat_streamly failed: {e}")
                    yield e
                    yield total_tokens
                    return

        assert False, "Shouldn't be here."

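The streamed output above interleaves answer text with <think>…</think> reasoning and <tool_call>…</tool_call> traces produced by _verbose_tool_use. A hedged sketch of a consumer-side splitter; the helper name and regex approach are illustrative and not part of this diff.

import re

TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)

def split_tool_traces(chunk: str):
    """Return (visible_text, tool_call_json_strings) for one streamed chunk."""
    payloads = TOOL_CALL_RE.findall(chunk)
    visible = TOOL_CALL_RE.sub("", chunk)
    return visible, payloads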
    async def _async_chat(self, history, gen_conf, **kwargs):
        logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
        if self.model_name.lower().find("qwq") >= 0:
            logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly")
            final_ans = ""
            tol_token = 0
            async for delta, tol in self._async_chat_stream(history, gen_conf, with_reasoning=False, **kwargs):
                if delta.startswith("<think>") or delta.endswith("</think>"):
                    continue
                final_ans += delta
                tol_token = tol

            if len(final_ans.strip()) == 0:
                final_ans = "**ERROR**: Empty response from reasoning model"

            return final_ans.strip(), tol_token

        if self.model_name.lower().find("qwen3") >= 0:
            kwargs["extra_body"] = {"enable_thinking": False}

        response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)

        if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
            return "", 0
        ans = response.choices[0].message.content.strip()
        if response.choices[0].finish_reason == "length":
            ans = self._length_stop(ans)
        return ans, total_token_count_from_response(response)

    async def async_chat(self, system, history, gen_conf={}, **kwargs):
        if system and history and history[0].get("role") != "system":
            history.insert(0, {"role": "system", "content": system})
        gen_conf = self._clean_conf(gen_conf)

        for attempt in range(self.max_retries + 1):
            try:
                return await self._async_chat(history, gen_conf, **kwargs)
            except Exception as e:
                e = await self._exceptions_async(e, attempt)
                if e:
                    return e, 0
        assert False, "Shouldn't be here."

    def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
        if system and history and history[0].get("role") != "system":
            history.insert(0, {"role": "system", "content": system})
@@ -642,66 +940,6 @@ class BaiChuanChat(Base):
        yield total_tokens


class ZhipuChat(Base):
    _FACTORY_NAME = "ZHIPU-AI"

    def __init__(self, key, model_name="glm-3-turbo", base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)

        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def _clean_conf(self, gen_conf):
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        gen_conf = self._clean_conf_plealty(gen_conf)
        return gen_conf

    def _clean_conf_plealty(self, gen_conf):
        if "presence_penalty" in gen_conf:
            del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            del gen_conf["frequency_penalty"]
        return gen_conf

    def chat_with_tools(self, system: str, history: list, gen_conf: dict):
        gen_conf = self._clean_conf_plealty(gen_conf)

        return super().chat_with_tools(system, history, gen_conf)

    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        if system and history and history[0].get("role") != "system":
            history.insert(0, {"role": "system", "content": system})
        gen_conf = self._clean_conf(gen_conf)
        ans = ""
        tk_count = 0
        try:
            logging.info(json.dumps(history, ensure_ascii=False, indent=2))
            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
            for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                delta = resp.choices[0].delta.content
                ans = delta
                if resp.choices[0].finish_reason == "length":
                    if is_chinese(ans):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                    tk_count = total_token_count_from_response(resp)
                if resp.choices[0].finish_reason == "stop":
                    tk_count = total_token_count_from_response(resp)
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count

    def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
        gen_conf = self._clean_conf_plealty(gen_conf)
        return super().chat_streamly_with_tools(system, history, gen_conf)


class LocalAIChat(Base):
    _FACTORY_NAME = "LocalAI"

@@ -1403,6 +1641,7 @@ class LiteLLMBase(ABC):
        "GiteeAI",
        "302.AI",
        "Jiekou.AI",
        "ZHIPU-AI",
    ]

    def __init__(self, key, model_name, base_url=None, **kwargs):
@@ -1482,6 +1721,7 @@ class LiteLLMBase(ABC):

    def _chat_streamly(self, history, gen_conf, **kwargs):
        logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
        gen_conf = self._clean_conf(gen_conf)
        reasoning_start = False

        completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
@@ -1525,6 +1765,96 @@ class LiteLLMBase(ABC):

            yield ans, tol

    async def async_chat(self, history, gen_conf, **kwargs):
        logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
        if self.model_name.lower().find("qwen3") >= 0:
            kwargs["extra_body"] = {"enable_thinking": False}

        completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf)

        for attempt in range(self.max_retries + 1):
            try:
                response = await litellm.acompletion(
                    **completion_args,
                    drop_params=True,
                    timeout=self.timeout,
                )

                if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
                    return "", 0
                ans = response.choices[0].message.content.strip()
                if response.choices[0].finish_reason == "length":
                    ans = self._length_stop(ans)

                return ans, total_token_count_from_response(response)
            except Exception as e:
                e = await self._exceptions_async(e, attempt)
                if e:
                    return e, 0

        assert False, "Shouldn't be here."

    async def async_chat_streamly(self, system, history, gen_conf, **kwargs):
        if system and history and history[0].get("role") != "system":
            history.insert(0, {"role": "system", "content": system})
        logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
        gen_conf = self._clean_conf(gen_conf)
        reasoning_start = False
        total_tokens = 0

        completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
        stop = kwargs.get("stop")
        if stop:
            completion_args["stop"] = stop

        for attempt in range(self.max_retries + 1):
            try:
                stream = await litellm.acompletion(
                    **completion_args,
                    drop_params=True,
                    timeout=self.timeout,
                )

                async for resp in stream:
                    if not hasattr(resp, "choices") or not resp.choices:
                        continue

                    delta = resp.choices[0].delta
                    if not hasattr(delta, "content") or delta.content is None:
                        delta.content = ""

                    if kwargs.get("with_reasoning", True) and hasattr(delta, "reasoning_content") and delta.reasoning_content:
                        ans = ""
                        if not reasoning_start:
                            reasoning_start = True
                            ans = "<think>"
                        ans += delta.reasoning_content + "</think>"
                    else:
                        reasoning_start = False
                        ans = delta.content

                    tol = total_token_count_from_response(resp)
                    if not tol:
                        tol = num_tokens_from_string(delta.content)
                    total_tokens += tol

                    finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
                    if finish_reason == "length":
                        if is_chinese(ans):
                            ans += LENGTH_NOTIFICATION_CN
                        else:
                            ans += LENGTH_NOTIFICATION_EN

                    yield ans
                yield total_tokens
                return
            except Exception as e:
                e = await self._exceptions_async(e, attempt)
                if e:
                    yield e
                    yield total_tokens
                    return
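async_chat_streamly forwards an optional stop keyword from **kwargs into the completion request. A hedged call sketch; the llm instance, prompt, and stop strings are illustrative.

async def stream_with_stop(llm):
    async for chunk in llm.async_chat_streamly(
        "You are terse.",
        [{"role": "user", "content": "List three colors."}],
        {"temperature": 0.2},
        stop=["\n\n"],  # forwarded into completion_args["stop"]
    ):
        if isinstance(chunk, int):
            break  # final item is the accumulated token count
        print(chunk, end="")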

    def _length_stop(self, ans):
        if is_chinese([ans]):
            return ans + LENGTH_NOTIFICATION_CN
@@ -1555,6 +1885,21 @@ class LiteLLMBase(ABC):

        return f"{ERROR_PREFIX}: {error_code} - {str(e)}"

    async def _exceptions_async(self, e, attempt) -> str | None:
        logging.exception("LiteLLMBase async completion")
        error_code = self._classify_error(e)
        if attempt == self.max_retries:
            error_code = LLMErrorCode.ERROR_MAX_RETRIES

        if self._should_retry(error_code):
            delay = self._get_delay()
            logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
            await asyncio.sleep(delay)
            return None
        msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}"
        logging.error(f"async_chat_streamly giving up: {msg}")
        return msg

    def _verbose_tool_use(self, name, args, res):
        return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"