mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Refa: cleanup synchronous functions in chat_model and implement synchronization for conversation and dialog chats (#11779)
### What problem does this PR solve? Cleanup synchronous functions in chat_model and implement synchronization for conversation and dialog chats. ### Type of change - [x] Refactoring - [x] Performance Improvement
This commit is contained in:
@ -16,15 +16,17 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import queue
|
||||
import re
|
||||
import threading
|
||||
from common.token_utils import num_tokens_from_string
|
||||
from functools import partial
|
||||
from typing import Generator
|
||||
from common.constants import LLMType
|
||||
|
||||
from api.db.db_models import LLM
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
|
||||
from common.constants import LLMType
|
||||
from common.token_utils import num_tokens_from_string
|
||||
|
||||
|
||||
class LLMService(CommonService):
|
||||
@ -33,6 +35,7 @@ class LLMService(CommonService):
|
||||
|
||||
def get_init_tenant_llm(user_id):
|
||||
from common import settings
|
||||
|
||||
tenant_llm = []
|
||||
|
||||
model_configs = {
|
||||
@ -193,7 +196,7 @@ class LLMBundle(LLM4Tenant):
|
||||
generation = self.langfuse.start_generation(
|
||||
trace_context=self.trace_context,
|
||||
name="stream_transcription",
|
||||
metadata={"model": self.llm_name}
|
||||
metadata={"model": self.llm_name},
|
||||
)
|
||||
final_text = ""
|
||||
used_tokens = 0
|
||||
@ -217,32 +220,34 @@ class LLMBundle(LLM4Tenant):
|
||||
if self.langfuse:
|
||||
generation.update(
|
||||
output={"output": final_text},
|
||||
usage_details={"total_tokens": used_tokens}
|
||||
usage_details={"total_tokens": used_tokens},
|
||||
)
|
||||
generation.end()
|
||||
|
||||
return
|
||||
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="stream_transcription", metadata={"model": self.llm_name})
|
||||
full_text, used_tokens = mdl.transcription(audio)
|
||||
if not TenantLLMService.increase_usage(
|
||||
self.tenant_id, self.llm_type, used_tokens
|
||||
):
|
||||
logging.error(
|
||||
f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}"
|
||||
generation = self.langfuse.start_generation(
|
||||
trace_context=self.trace_context,
|
||||
name="stream_transcription",
|
||||
metadata={"model": self.llm_name},
|
||||
)
|
||||
|
||||
full_text, used_tokens = mdl.transcription(audio)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
logging.error(f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}")
|
||||
|
||||
if self.langfuse:
|
||||
generation.update(
|
||||
output={"output": full_text},
|
||||
usage_details={"total_tokens": used_tokens}
|
||||
usage_details={"total_tokens": used_tokens},
|
||||
)
|
||||
generation.end()
|
||||
|
||||
yield {
|
||||
"event": "final",
|
||||
"text": full_text,
|
||||
"streaming": False
|
||||
"streaming": False,
|
||||
}
|
||||
|
||||
def tts(self, text: str) -> Generator[bytes, None, None]:
|
||||
@ -289,61 +294,79 @@ class LLMBundle(LLM4Tenant):
|
||||
return kwargs
|
||||
else:
|
||||
return {k: v for k, v in kwargs.items() if k in allowed_params}
|
||||
|
||||
def _run_coroutine_sync(self, coro):
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
|
||||
result_queue: queue.Queue = queue.Queue()
|
||||
|
||||
def runner():
|
||||
try:
|
||||
result_queue.put((True, asyncio.run(coro)))
|
||||
except Exception as e:
|
||||
result_queue.put((False, e))
|
||||
|
||||
thread = threading.Thread(target=runner, daemon=True)
|
||||
thread.start()
|
||||
thread.join()
|
||||
|
||||
success, value = result_queue.get_nowait()
|
||||
if success:
|
||||
return value
|
||||
raise value
|
||||
|
||||
def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
|
||||
return self._run_coroutine_sync(self.async_chat(system, history, gen_conf, **kwargs))
|
||||
|
||||
chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
|
||||
if self.is_tools and self.mdl.is_tools:
|
||||
chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
|
||||
def _sync_from_async_stream(self, async_gen_fn, *args, **kwargs):
|
||||
result_queue: queue.Queue = queue.Queue()
|
||||
|
||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||
txt, used_tokens = chat_partial(**use_kwargs)
|
||||
txt = self._remove_reasoning_content(txt)
|
||||
def runner():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
if not self.verbose_tool_use:
|
||||
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
||||
async def consume():
|
||||
try:
|
||||
async for item in async_gen_fn(*args, **kwargs):
|
||||
result_queue.put(item)
|
||||
except Exception as e:
|
||||
result_queue.put(e)
|
||||
finally:
|
||||
result_queue.put(StopIteration)
|
||||
|
||||
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
|
||||
loop.run_until_complete(consume())
|
||||
loop.close()
|
||||
|
||||
if self.langfuse:
|
||||
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
|
||||
generation.end()
|
||||
threading.Thread(target=runner, daemon=True).start()
|
||||
|
||||
return txt
|
||||
while True:
|
||||
item = result_queue.get()
|
||||
if item is StopIteration:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
yield item
|
||||
|
||||
def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
|
||||
|
||||
ans = ""
|
||||
chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf)
|
||||
total_tokens = 0
|
||||
if self.is_tools and self.mdl.is_tools:
|
||||
chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf)
|
||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||
for txt in chat_partial(**use_kwargs):
|
||||
for txt in self._sync_from_async_stream(self.async_chat_streamly, system, history, gen_conf, **kwargs):
|
||||
if isinstance(txt, int):
|
||||
total_tokens = txt
|
||||
if self.langfuse:
|
||||
generation.update(output={"output": ans})
|
||||
generation.end()
|
||||
break
|
||||
|
||||
if txt.endswith("</think>"):
|
||||
ans = ans[: -len("</think>")]
|
||||
ans = txt[: -len("</think>")]
|
||||
continue
|
||||
|
||||
if not self.verbose_tool_use:
|
||||
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
||||
|
||||
ans += txt
|
||||
# cancatination has beend done in async_chat_streamly
|
||||
ans = txt
|
||||
yield ans
|
||||
|
||||
if total_tokens > 0:
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
||||
|
||||
def _bridge_sync_stream(self, gen):
|
||||
loop = asyncio.get_running_loop()
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
@ -352,7 +375,7 @@ class LLMBundle(LLM4Tenant):
|
||||
try:
|
||||
for item in gen:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, item)
|
||||
except Exception as e: # pragma: no cover
|
||||
except Exception as e:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, e)
|
||||
finally:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
|
||||
@ -361,18 +384,27 @@ class LLMBundle(LLM4Tenant):
|
||||
return queue
|
||||
|
||||
async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
||||
chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
|
||||
if self.is_tools and self.mdl.is_tools and hasattr(self.mdl, "chat_with_tools"):
|
||||
chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
|
||||
if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_with_tools"):
|
||||
base_fn = self.mdl.async_chat_with_tools
|
||||
elif hasattr(self.mdl, "async_chat"):
|
||||
base_fn = self.mdl.async_chat
|
||||
else:
|
||||
raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
|
||||
|
||||
generation = None
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
|
||||
|
||||
chat_partial = partial(base_fn, system, history, gen_conf)
|
||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||
|
||||
if hasattr(self.mdl, "async_chat_with_tools") and self.is_tools and self.mdl.is_tools:
|
||||
txt, used_tokens = await self.mdl.async_chat_with_tools(system, history, gen_conf, **use_kwargs)
|
||||
elif hasattr(self.mdl, "async_chat"):
|
||||
txt, used_tokens = await self.mdl.async_chat(system, history, gen_conf, **use_kwargs)
|
||||
else:
|
||||
txt, used_tokens = await asyncio.to_thread(chat_partial, **use_kwargs)
|
||||
try:
|
||||
txt, used_tokens = await chat_partial(**use_kwargs)
|
||||
except Exception as e:
|
||||
if generation:
|
||||
generation.update(output={"error": str(e)})
|
||||
generation.end()
|
||||
raise
|
||||
|
||||
txt = self._remove_reasoning_content(txt)
|
||||
if not self.verbose_tool_use:
|
||||
@ -381,49 +413,51 @@ class LLMBundle(LLM4Tenant):
|
||||
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
|
||||
|
||||
if generation:
|
||||
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
|
||||
generation.end()
|
||||
|
||||
return txt
|
||||
|
||||
async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
||||
total_tokens = 0
|
||||
ans = ""
|
||||
if self.is_tools and self.mdl.is_tools:
|
||||
if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"):
|
||||
stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
|
||||
else:
|
||||
elif hasattr(self.mdl, "async_chat_streamly"):
|
||||
stream_fn = getattr(self.mdl, "async_chat_streamly", None)
|
||||
else:
|
||||
raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
|
||||
|
||||
generation = None
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
|
||||
|
||||
if stream_fn:
|
||||
chat_partial = partial(stream_fn, system, history, gen_conf)
|
||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||
async for txt in chat_partial(**use_kwargs):
|
||||
if isinstance(txt, int):
|
||||
total_tokens = txt
|
||||
break
|
||||
try:
|
||||
async for txt in chat_partial(**use_kwargs):
|
||||
if isinstance(txt, int):
|
||||
total_tokens = txt
|
||||
break
|
||||
|
||||
if txt.endswith("</think>"):
|
||||
ans = ans[: -len("</think>")]
|
||||
if txt.endswith("</think>"):
|
||||
ans = ans[: -len("</think>")]
|
||||
|
||||
if not self.verbose_tool_use:
|
||||
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
||||
if not self.verbose_tool_use:
|
||||
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
||||
|
||||
ans += txt
|
||||
yield ans
|
||||
ans += txt
|
||||
yield ans
|
||||
except Exception as e:
|
||||
if generation:
|
||||
generation.update(output={"error": str(e)})
|
||||
generation.end()
|
||||
raise
|
||||
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
||||
if generation:
|
||||
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
|
||||
generation.end()
|
||||
return
|
||||
|
||||
chat_partial = partial(self.mdl.chat_streamly_with_tools if (self.is_tools and self.mdl.is_tools) else self.mdl.chat_streamly, system, history, gen_conf)
|
||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||
queue = self._bridge_sync_stream(chat_partial(**use_kwargs))
|
||||
while True:
|
||||
item = await queue.get()
|
||||
if item is StopAsyncIteration:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
if isinstance(item, int):
|
||||
total_tokens = item
|
||||
break
|
||||
yield item
|
||||
|
||||
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
||||
|
||||
Reference in New Issue
Block a user