Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Refa: chat with tools. (#8210)
### What problem does this PR solve?

### Type of change

- [x] Refactoring
@@ -26,6 +26,7 @@ from http import HTTPStatus
 from typing import Any, Protocol
 from urllib.parse import urljoin

+import json_repair
 import openai
 import requests
 from dashscope import Generation
@@ -67,11 +68,12 @@ class Base(ABC):
         # Configure retry parameters
         self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
         self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
+        self.max_rounds = kwargs.get("max_rounds", 5)
         self.is_tools = False

-    def _get_delay(self, attempt):
+    def _get_delay(self):
         """Calculate retry delay time"""
-        return self.base_delay * (2**attempt) + random.uniform(0, 0.5)
+        return self.base_delay + random.uniform(0, 0.5)

     def _classify_error(self, error):
         """Classify error based on error message content"""
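A note on pacing implied by this hunk: `_get_delay` no longer scales with the attempt number, so every retry waits roughly the configured base interval plus jitter (the "exponential backoff" comments elsewhere remain, but the delay itself is now flat). A minimal sketch of the before/after behaviour; the standalone functions and constant below are illustrative, not part of the codebase:

```python
import random

BASE_DELAY = 2.0  # mirrors the LLM_BASE_DELAY default above


def old_delay(attempt: int) -> float:
    # Pre-refactor: exponential backoff with jitter (2s, 4s, 8s, ...).
    return BASE_DELAY * (2 ** attempt) + random.uniform(0, 0.5)


def new_delay() -> float:
    # Post-refactor: flat retry interval with jitter (~2s on every attempt).
    return BASE_DELAY + random.uniform(0, 0.5)


if __name__ == "__main__":
    for attempt in range(3):
        print(f"attempt {attempt}: old={old_delay(attempt):.2f}s, new={new_delay():.2f}s")
```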
@@ -116,6 +118,29 @@ class Base(ABC):
                 ans += LENGTH_NOTIFICATION_EN
         return ans, self.total_token_count(response)

+    def _length_stop(self, ans):
+        if is_chinese([ans]):
+            return ans + LENGTH_NOTIFICATION_CN
+        return ans + LENGTH_NOTIFICATION_EN
+
+    def _exceptions(self, e, attempt):
+        logging.exception("OpenAI cat_with_tools")
+        # Classify the error
+        error_code = self._classify_error(e)
+
+        # Check if it's a rate limit error or server error and not the last attempt
+        should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries
+
+        if should_retry:
+            delay = self._get_delay()
+            logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
+            time.sleep(delay)
+        else:
+            # For non-rate limit errors or the last attempt, return an error message
+            if attempt == self.max_retries:
+                error_code = ERROR_MAX_RETRIES
+            return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+
     def bind_tools(self, toolcall_session, tools):
         if not (toolcall_session and tools):
             return
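`_exceptions` centralizes the retry decision for the callers below: it logs the failure, classifies it, sleeps and returns `None` when the error is retryable and attempts remain, and otherwise returns a formatted error string. A minimal sketch of that calling convention, with a placeholder `flaky_call` standing in for the real completion request:

```python
def call_with_retries(llm, flaky_call):
    # Sketch of the pattern chat() and chat_with_tools() follow after this refactor:
    # None from _exceptions() means "already slept, try again"; a string means "give up".
    for attempt in range(llm.max_retries + 1):
        try:
            return flaky_call(), 0
        except Exception as exc:
            err = llm._exceptions(exc, attempt)
            if err:
                # Non-retryable error, or retries exhausted: surface the message.
                return err, 0
    assert False, "Shouldn't be here."
```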
@@ -124,76 +149,48 @@ class Base(ABC):
         self.tools = tools

     def chat_with_tools(self, system: str, history: list, gen_conf: dict):
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-
-        tools = self.tools
-
+        gen_conf = self._clean_conf()
         if system:
             history.insert(0, {"role": "system", "content": system})

         ans = ""
         tk_count = 0
+        hist = deepcopy(history)
         # Implement exponential backoff retry strategy
-        for attempt in range(self.max_retries):
-            try:
-                response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf)
-                assistant_output = response.choices[0].message
-                if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
-                    ans += "<think>" + ans + "</think>"
-                ans += response.choices[0].message.content
-
-                if not response.choices[0].message.tool_calls:
-                    tk_count += self.total_token_count(response)
-                    if response.choices[0].finish_reason == "length":
-                        if is_chinese([ans]):
-                            ans += LENGTH_NOTIFICATION_CN
-                        else:
-                            ans += LENGTH_NOTIFICATION_EN
-                    return ans, tk_count
-
-                tk_count += self.total_token_count(response)
-                history.append(assistant_output)
-
-                for tool_call in response.choices[0].message.tool_calls:
-                    name = tool_call.function.name
-                    args = json.loads(tool_call.function.arguments)
-
-                    tool_response = self.toolcall_session.tool_call(name, args)
-                    history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
-
-                final_response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf)
-                assistant_output = final_response.choices[0].message
-                if "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
-                    ans += "<think>" + ans + "</think>"
-                ans += final_response.choices[0].message.content
-                if final_response.choices[0].finish_reason == "length":
-                    tk_count += self.total_token_count(response)
-                    if is_chinese([ans]):
-                        ans += LENGTH_NOTIFICATION_CN
-                    else:
-                        ans += LENGTH_NOTIFICATION_EN
-                    return ans, tk_count
-
-                return ans, tk_count
-
-            except Exception as e:
-                logging.exception("OpenAI cat_with_tools")
-                # Classify the error
-                error_code = self._classify_error(e)
-
-                # Check if it's a rate limit error or server error and not the last attempt
-                should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1
-
-                if should_retry:
-                    delay = self._get_delay(attempt)
-                    logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
-                    time.sleep(delay)
-                else:
-                    # For non-rate limit errors or the last attempt, return an error message
-                    if attempt == self.max_retries - 1:
-                        error_code = ERROR_MAX_RETRIES
-                    return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0
+        for attempt in range(self.max_retries+1):
+            history = hist
+            for _ in range(self.max_rounds*2):
+                try:
+                    response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
+                    tk_count += self.total_token_count(response)
+                    if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
+                        raise Exception("500 response structure error.")
+
+                    if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
+                        if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
+                            ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"
+                        ans += response.choices[0].message.content
+                        if response.choices[0].finish_reason == "length":
+                            ans = self._length_stop(ans)
+                        return ans, tk_count
+
+                    for tool_call in response.choices[0].message.tool_calls:
+                        name = tool_call.function.name
+                        try:
+                            args = json_repair.loads(tool_call.function.arguments)
+                            tool_response = self.toolcall_session.tool_call(name, args)
+                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+                        except Exception as e:
+                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
+
+                except Exception as e:
+                    e = self._exceptions(e, attempt)
+                    if e:
+                        return e, tk_count
+        assert False, "Shouldn't be here."

     def chat(self, system, history, gen_conf):
         if system:
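Two behavioural details stand out in the rewritten loop: tool rounds are capped at `max_rounds * 2`, and tool-call arguments are parsed with `json_repair.loads` rather than `json.loads`, with a failed call fed back to the model as a `tool` message instead of aborting the conversation. A small, self-contained illustration of the argument repair; the truncated JSON string is made up:

```python
import json

import json_repair

# Models sometimes emit truncated or otherwise malformed tool-call arguments.
raw_args = '{"query": "ragflow retry policy", "top_k": 5'

try:
    json.loads(raw_args)  # strict parser: raises json.JSONDecodeError
except json.JSONDecodeError as exc:
    print("json.loads failed:", exc)

args = json_repair.loads(raw_args)  # best-effort repair of the broken payload
print(args)  # expected: {'query': 'ragflow retry policy', 'top_k': 5}
```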
@@ -201,26 +198,14 @@ class Base(ABC):
         gen_conf = self._clean_conf(gen_conf)

         # Implement exponential backoff retry strategy
-        for attempt in range(self.max_retries):
+        for attempt in range(self.max_retries+1):
             try:
                 return self._chat(history, gen_conf)
             except Exception as e:
-                logging.exception("chat_model.Base.chat got exception")
-                # Classify the error
-                error_code = self._classify_error(e)
-
-                # Check if it's a rate limit error or server error and not the last attempt
-                should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1
-
-                if should_retry:
-                    delay = self._get_delay(attempt)
-                    logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
-                    time.sleep(delay)
-                else:
-                    # For non-rate limit errors or the last attempt, return an error message
-                    if attempt == self.max_retries - 1:
-                        error_code = ERROR_MAX_RETRIES
-                    return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0
+                e = self._exceptions(e, attempt)
+                if e:
+                    return e, 0
+        assert False, "Shouldn't be here."

     def _wrap_toolcall_message(self, stream):
         final_tool_calls = {}
@@ -241,41 +226,48 @@ class Base(ABC):
             del gen_conf["max_tokens"]

         tools = self.tools

         if system:
             history.insert(0, {"role": "system", "content": system})

-        ans = ""
         total_tokens = 0
-        reasoning_start = False
-        finish_completion = False
-        final_tool_calls = {}
-        try:
-            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
-            while not finish_completion:
-                for resp in response:
-                    if resp.choices[0].delta.tool_calls:
-                        for tool_call in resp.choices[0].delta.tool_calls or []:
-                            index = tool_call.index
+        hist = deepcopy(history)
+        # Implement exponential backoff retry strategy
+        for attempt in range(self.max_retries+1):
+            history = hist
+            for _ in range(self.max_rounds*2):
+                reasoning_start = False
+                try:
+                    response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
+                    final_tool_calls = {}
+                    answer = ""
+                    for resp in response:
+                        if resp.choices[0].delta.tool_calls:
+                            for tool_call in resp.choices[0].delta.tool_calls or []:
+                                index = tool_call.index

                                 if index not in final_tool_calls:
                                     final_tool_calls[index] = tool_call
                                 else:
                                     final_tool_calls[index].function.arguments += tool_call.function.arguments
-                    else:
-                        if not resp.choices:
                             continue

+                        if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
+                            raise Exception("500 response structure error.")
+
                         if not resp.choices[0].delta.content:
                             resp.choices[0].delta.content = ""

                         if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
                             ans = ""
                             if not reasoning_start:
                                 reasoning_start = True
                                 ans = "<think>"
                             ans += resp.choices[0].delta.reasoning_content + "</think>"
+                            yield ans
                         else:
                             reasoning_start = False
-                            ans = resp.choices[0].delta.content
+                            answer += resp.choices[0].delta.content
+                            yield resp.choices[0].delta.content

                         tol = self.total_token_count(resp)
                         if not tol:
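In the streaming path, tool calls arrive as chunked deltas: each delta carries an `index` plus a fragment of `function.arguments`, and the loop above concatenates the fragments per index until the stream ends. A minimal sketch of that accumulation using stand-in delta objects instead of a live stream:

```python
from types import SimpleNamespace


def delta_tool_call(index, name=None, arguments=""):
    # Stand-in for an OpenAI streamed tool-call delta (illustrative only).
    return SimpleNamespace(index=index, function=SimpleNamespace(name=name, arguments=arguments))


# Fragments as they might arrive across several stream chunks.
chunks = [
    delta_tool_call(0, name="search", arguments='{"que'),
    delta_tool_call(0, arguments='ry": "rag'),
    delta_tool_call(0, arguments='flow"}'),
]

final_tool_calls = {}
for tool_call in chunks:
    index = tool_call.index
    if index not in final_tool_calls:
        final_tool_calls[index] = tool_call
    else:
        final_tool_calls[index].function.arguments += tool_call.function.arguments

print(final_tool_calls[0].function.name)       # search
print(final_tool_calls[0].function.arguments)  # {"query": "ragflow"}
```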
@@ -283,18 +275,18 @@ class Base(ABC):
                         else:
                             total_tokens += tol

-                        finish_reason = resp.choices[0].finish_reason
-                        if finish_reason == "tool_calls" and final_tool_calls:
-                            for tool_call in final_tool_calls.values():
-                                name = tool_call.function.name
-                                try:
-                                    args = json.loads(tool_call.function.arguments)
-                                except Exception as e:
-                                    logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
-                                    yield ans + "\n**ERROR**: " + str(e)
-                                    finish_completion = True
-                                    break
+                        finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
+                        if finish_reason == "length":
+                            yield self._length_stop("")

+                    if answer:
+                        yield total_tokens
+                        return
+
+                    for tool_call in final_tool_calls.values():
+                        name = tool_call.function.name
+                        try:
+                            args = json_repair.loads(tool_call.function.arguments)
                             tool_response = self.toolcall_session.tool_call(name, args)
                             history.append(
                                 {
@@ -313,26 +305,16 @@ class Base(ABC):
                                 }
                             )
                             history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
-                        final_tool_calls = {}
-                        response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
-                        continue
-                    if finish_reason == "length":
-                        if is_chinese(ans):
-                            ans += LENGTH_NOTIFICATION_CN
-                        else:
-                            ans += LENGTH_NOTIFICATION_EN
-                        return ans, total_tokens
-                    if finish_reason == "stop":
-                        finish_completion = True
-                        yield ans
-                        break
-                yield ans
-                continue
-
-        except openai.APIError as e:
-            yield ans + "\n**ERROR**: " + str(e)
-
-        yield total_tokens
+                        except Exception as e:
+                            logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
+                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
+                except Exception as e:
+                    e = self._exceptions(e, attempt)
+                    if e:
+                        yield total_tokens
+                        return
+
+        assert False, "Shouldn't be here."

     def chat_streamly(self, system, history, gen_conf):
         if system:
@@ -636,49 +618,21 @@ class QWenChat(Base):
             return "".join(result_list[:-1]), result_list[-1]

     def _chat(self, history, gen_conf):
-        tk_count = 0
         if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
-            try:
-                response = super()._chat(history, gen_conf)
-                return response
-            except Exception as e:
-                error_msg = str(e).lower()
-                if "invalid_parameter_error" in error_msg and "only support stream mode" in error_msg:
-                    return self._simulate_one_shot_from_stream(history, gen_conf)
-                else:
-                    return "**ERROR**: " + str(e), tk_count
-
-        try:
-            ans = ""
-            response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
-            if response.status_code == HTTPStatus.OK:
-                ans += response.output.choices[0]["message"]["content"]
-                tk_count += self.total_token_count(response)
-                if response.output.choices[0].get("finish_reason", "") == "length":
-                    if is_chinese([ans]):
-                        ans += LENGTH_NOTIFICATION_CN
-                    else:
-                        ans += LENGTH_NOTIFICATION_EN
-                return ans, tk_count
-            return "**ERROR**: " + response.message, tk_count
-        except Exception as e:
-            error_msg = str(e).lower()
-            if "invalid_parameter_error" in error_msg and "only support stream mode" in error_msg:
-                return self._simulate_one_shot_from_stream(history, gen_conf)
-            else:
-                return "**ERROR**: " + str(e), tk_count
-
-    def _simulate_one_shot_from_stream(self, history, gen_conf):
-        """
-        Handles models that require streaming output but need one-shot response.
-        """
-        g = self._chat_streamly("", history, gen_conf, incremental_output=True)
-        result_list = list(g)
-        error_msg_list = [item for item in result_list if str(item).find("**ERROR**") >= 0]
-        if len(error_msg_list) > 0:
-            return "**ERROR**: " + "".join(error_msg_list), 0
-        else:
-            return "".join(result_list[:-1]), result_list[-1]
+            return super()._chat(history, gen_conf)
+
+        response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
+        ans = ""
+        tk_count = 0
+        if response.status_code == HTTPStatus.OK:
+            ans += response.output.choices[0]["message"]["content"]
+            tk_count += self.total_token_count(response)
+            if response.output.choices[0].get("finish_reason", "") == "length":
+                if is_chinese([ans]):
+                    ans += LENGTH_NOTIFICATION_CN
+                else:
+                    ans += LENGTH_NOTIFICATION_EN
+            return ans, tk_count
+        return "**ERROR**: " + response.message, tk_count

     def _wrap_toolcall_message(self, old_message, message):
         if not old_message:
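After this hunk, QWen reasoning and vision models delegate straight to the OpenAI-compatible `super()._chat()`, while the other Qwen models keep the direct DashScope path; the per-model try/except and the `_simulate_one_shot_from_stream` fallback are gone, so failures surface to the retry loop in `Base.chat()` shown earlier. For reference, a minimal standalone sketch of the DashScope path kept here; it assumes `DASHSCOPE_API_KEY` is set in the environment and the model name is only an example:

```python
from http import HTTPStatus

from dashscope import Generation

history = [{"role": "user", "content": "Say hello in one short sentence."}]
response = Generation.call("qwen-plus", messages=history, result_format="message")

if response.status_code == HTTPStatus.OK:
    # Same fields the _chat() above reads from the DashScope response.
    print(response.output.choices[0]["message"]["content"])
else:
    print("**ERROR**: " + response.message)
```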
@@ -971,10 +925,10 @@ class LocalAIChat(Base):


 class LocalLLM(Base):
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)
         from jina import Client

         self.client = Client(port=12345, protocol="grpc", asyncio=True)

     def _prepare_prompt(self, system, history, gen_conf):
@@ -1031,7 +985,13 @@ class VolcEngineChat(Base):


 class MiniMaxChat(Base):
-    def __init__(self, key, model_name, base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", **kwargs):
+    def __init__(
+        self,
+        key,
+        model_name,
+        base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
+        **kwargs
+    ):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

         if not base_url:
@@ -1263,7 +1223,6 @@ class GeminiChat(Base):

     def _chat(self, history, gen_conf):
         from google.generativeai.types import content_types
-
         system = history[0]["content"] if history and history[0]["role"] == "system" else ""
         hist = []
         for item in history:
@@ -1921,4 +1880,4 @@ class GPUStackChat(Base):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
         base_url = urljoin(base_url, "v1")
         super().__init__(key, model_name, base_url, **kwargs)
@@ -119,7 +119,7 @@ def kb_prompt(kbinfos, max_tokens):
     doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []})
     for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
         cnt = f"---\nID: {i}\n" + (f"URL: {ck['url']}\n" if "url" in ck else "")
-        cnt += ck["content_with_weight"]
+        cnt += re.sub(r"( style=\"[^\"]+\"|</?(html|body|head|title)>|<!DOCTYPE html>)", " ", ck["content_with_weight"], flags=re.DOTALL|re.IGNORECASE)
         doc2chunks[ck["docnm_kwd"]]["chunks"].append(cnt)
         doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})
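The `kb_prompt` change pre-cleans each retrieved chunk: inline `style="..."` attributes, `<html>/<body>/<head>/<title>` tags, and the doctype are blanked before the chunk is packed into the prompt, so HTML scaffolding stops wasting prompt tokens. A quick illustration of the same substitution on a made-up chunk:

```python
import re

chunk = '<!DOCTYPE html><html><body><p style="color:red">Quarterly revenue grew 12%.</p></body></html>'

cleaned = re.sub(
    r"( style=\"[^\"]+\"|</?(html|body|head|title)>|<!DOCTYPE html>)",
    " ",
    chunk,
    flags=re.DOTALL | re.IGNORECASE,
)
# Doctype, html/body tags, and the style attribute become spaces; the <p> text itself survives.
print(cleaned)
```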