diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index fbfa7d65e..046e0e274 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import inspect import logging import re from functools import partial @@ -377,7 +378,24 @@ class LLMBundle: return txt return txt[last_think_end + len("") :] + + @staticmethod + def _clean_param(chat_partial, **kwargs): + func = chat_partial.func + sig = inspect.signature(func) + keyword_args = [] + support_var_args = False + for param in sig.parameters.values(): + if param.kind == inspect.Parameter.VAR_KEYWORD or param.kind == inspect.Parameter.VAR_POSITIONAL: + support_var_args = True + elif param.kind == inspect.Parameter.KEYWORD_ONLY: + keyword_args.append(param.name) + use_kwargs = kwargs + if not support_var_args: + use_kwargs = {k: v for k, v in kwargs.items() if k in keyword_args} + return use_kwargs + def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str: if self.langfuse: generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history}) @@ -385,8 +403,9 @@ class LLMBundle: chat_partial = partial(self.mdl.chat, system, history, gen_conf) if self.is_tools and self.mdl.is_tools: chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf) - - txt, used_tokens = chat_partial(**kwargs) + + use_kwargs = self._clean_param(chat_partial, **kwargs) + txt, used_tokens = chat_partial(**use_kwargs) txt = self._remove_reasoning_content(txt) if not self.verbose_tool_use: @@ -410,8 +429,8 @@ class LLMBundle: total_tokens = 0 if self.is_tools and self.mdl.is_tools: chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf) - - for txt in chat_partial(**kwargs): + use_kwargs = self._clean_param(chat_partial, **kwargs) + for txt in chat_partial(**use_kwargs): if isinstance(txt, int): total_tokens = txt if self.langfuse: