diff --git a/Dockerfile b/Dockerfile index 057f347c8..f44b643bf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -105,10 +105,10 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \ apt update && \ arch="$(uname -m)"; \ if [ "$arch" = "arm64" ] || [ "$arch" = "aarch64" ]; then \ - # ARM64 (macOS/Apple Silicon or Linux aarch64) + # ARM64 (macOS/Apple Silicon or Linux aarch64) \ ACCEPT_EULA=Y apt install -y unixodbc-dev msodbcsql18; \ else \ - # x86_64 or others + # x86_64 or others \ ACCEPT_EULA=Y apt install -y unixodbc-dev msodbcsql17; \ fi || \ { echo "Failed to install ODBC driver"; exit 1; } diff --git a/agent/__init__.py b/agent/__init__.py index a42cd9a6d..177b91dd0 100644 --- a/agent/__init__.py +++ b/agent/__init__.py @@ -13,6 +13,3 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -# from beartype.claw import beartype_this_package -# beartype_this_package() diff --git a/agent/component/base.py b/agent/component/base.py index 321907dbe..63bff17f8 100644 --- a/agent/component/base.py +++ b/agent/component/base.py @@ -27,7 +27,6 @@ import pandas as pd from agent import settings from common.connection_utils import timeout - _FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params" _DEPRECATED_PARAMS = "_deprecated_params" _USER_FEEDED_PARAMS = "_user_feeded_params" @@ -97,7 +96,7 @@ class ComponentParamBase(ABC): def _recursive_convert_obj_to_dict(obj): ret_dict = {} if isinstance(obj, dict): - for k,v in obj.items(): + for k, v in obj.items(): if isinstance(v, dict) or (v and type(v).__name__ not in dir(builtins)): ret_dict[k] = _recursive_convert_obj_to_dict(v) else: @@ -253,96 +252,65 @@ class ComponentParamBase(ABC): self._validate_param(attr, validation_json) @staticmethod - def check_string(param, descr): + def check_string(param, description): if type(param).__name__ not in ["str"]: - raise ValueError( - descr + " {} not supported, should be string type".format(param) - ) + raise ValueError(description + " {} not supported, should be string type".format(param)) @staticmethod - def check_empty(param, descr): + def check_empty(param, description): if not param: - raise ValueError( - descr + " does not support empty value." - ) + raise ValueError(description + " does not support empty value.") @staticmethod - def check_positive_integer(param, descr): + def check_positive_integer(param, description): if type(param).__name__ not in ["int", "long"] or param <= 0: - raise ValueError( - descr + " {} not supported, should be positive integer".format(param) - ) + raise ValueError(description + " {} not supported, should be positive integer".format(param)) @staticmethod - def check_positive_number(param, descr): + def check_positive_number(param, description): if type(param).__name__ not in ["float", "int", "long"] or param <= 0: - raise ValueError( - descr + " {} not supported, should be positive numeric".format(param) - ) + raise ValueError(description + " {} not supported, should be positive numeric".format(param)) @staticmethod - def check_nonnegative_number(param, descr): + def check_nonnegative_number(param, description): if type(param).__name__ not in ["float", "int", "long"] or param < 0: - raise ValueError( - descr - + " {} not supported, should be non-negative numeric".format(param) - ) + raise ValueError(description + " {} not supported, should be non-negative numeric".format(param)) @staticmethod - def check_decimal_float(param, descr): + def check_decimal_float(param, description): if type(param).__name__ not in ["float", "int"] or param < 0 or param > 1: - raise ValueError( - descr - + " {} not supported, should be a float number in range [0, 1]".format( - param - ) - ) + raise ValueError(description + " {} not supported, should be a float number in range [0, 1]".format(param)) @staticmethod - def check_boolean(param, descr): + def check_boolean(param, description): if type(param).__name__ != "bool": - raise ValueError( - descr + " {} not supported, should be bool type".format(param) - ) + raise ValueError(description + " {} not supported, should be bool type".format(param)) @staticmethod - def check_open_unit_interval(param, descr): + def check_open_unit_interval(param, description): if type(param).__name__ not in ["float"] or param <= 0 or param >= 1: - raise ValueError( - descr + " should be a numeric number between 0 and 1 exclusively" - ) + raise ValueError(description + " should be a numeric number between 0 and 1 exclusively") @staticmethod - def check_valid_value(param, descr, valid_values): + def check_valid_value(param, description, valid_values): if param not in valid_values: - raise ValueError( - descr - + " {} is not supported, it should be in {}".format(param, valid_values) - ) + raise ValueError(description + " {} is not supported, it should be in {}".format(param, valid_values)) @staticmethod - def check_defined_type(param, descr, types): + def check_defined_type(param, description, types): if type(param).__name__ not in types: - raise ValueError( - descr + " {} not supported, should be one of {}".format(param, types) - ) + raise ValueError(description + " {} not supported, should be one of {}".format(param, types)) @staticmethod - def check_and_change_lower(param, valid_list, descr=""): + def check_and_change_lower(param, valid_list, description=""): if type(param).__name__ != "str": - raise ValueError( - descr - + " {} not supported, should be one of {}".format(param, valid_list) - ) + raise ValueError(description + " {} not supported, should be one of {}".format(param, valid_list)) lower_param = param.lower() if lower_param in valid_list: return lower_param else: - raise ValueError( - descr - + " {} not supported, should be one of {}".format(param, valid_list) - ) + raise ValueError(description + " {} not supported, should be one of {}".format(param, valid_list)) @staticmethod def _greater_equal_than(value, limit): @@ -374,16 +342,16 @@ class ComponentParamBase(ABC): def _not_in(value, wrong_value_list): return value not in wrong_value_list - def _warn_deprecated_param(self, param_name, descr): + def _warn_deprecated_param(self, param_name, description): if self._deprecated_params_set.get(param_name): logging.warning( - f"{descr} {param_name} is deprecated and ignored in this version." + f"{description} {param_name} is deprecated and ignored in this version." ) - def _warn_to_deprecate_param(self, param_name, descr, new_param): + def _warn_to_deprecate_param(self, param_name, description, new_param): if self._deprecated_params_set.get(param_name): logging.warning( - f"{descr} {param_name} will be deprecated in future release; " + f"{description} {param_name} will be deprecated in future release; " f"please use {new_param} instead." ) return True @@ -407,7 +375,7 @@ class ComponentBase(ABC): "params": {} }}""".format(self.component_name, self._param - ) + ) def __init__(self, canvas, id, param: ComponentParamBase): from agent.canvas import Graph # Local import to avoid cyclic dependency @@ -473,14 +441,14 @@ class ComponentBase(ABC): self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time")) return self.output() - @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))) def _invoke(self, **kwargs): raise NotImplementedError() - def output(self, var_nm: str=None) -> Union[dict[str, Any], Any]: + def output(self, var_nm: str = None) -> Union[dict[str, Any], Any]: if var_nm: return self._param.outputs.get(var_nm, {}).get("value", "") - return {k: o.get("value") for k,o in self._param.outputs.items()} + return {k: o.get("value") for k, o in self._param.outputs.items()} def set_output(self, key: str, value: Any): if key not in self._param.outputs: @@ -491,18 +459,18 @@ class ComponentBase(ABC): return self._param.outputs.get("_ERROR", {}).get("value") def reset(self, only_output=False): - outputs: dict = self._param.outputs # for better performance + outputs: dict = self._param.outputs # for better performance for k in outputs.keys(): outputs[k]["value"] = None if only_output: return - inputs: dict = self._param.inputs # for better performance + inputs: dict = self._param.inputs # for better performance for k in inputs.keys(): inputs[k]["value"] = None self._param.debug_inputs = {} - def get_input(self, key: str=None) -> Union[Any, dict[str, Any]]: + def get_input(self, key: str = None) -> Union[Any, dict[str, Any]]: if key: return self._param.inputs.get(key, {}).get("value") @@ -526,13 +494,13 @@ class ComponentBase(ABC): def get_input_elements_from_text(self, txt: str) -> dict[str, dict[str, str]]: res = {} - for r in re.finditer(self.variable_ref_patt, txt, flags=re.IGNORECASE|re.DOTALL): + for r in re.finditer(self.variable_ref_patt, txt, flags=re.IGNORECASE | re.DOTALL): exp = r.group(1) - cpn_id, var_nm = exp.split("@") if exp.find("@")>0 else ("", exp) + cpn_id, var_nm = exp.split("@") if exp.find("@") > 0 else ("", exp) res[exp] = { - "name": (self._canvas.get_component_name(cpn_id) +f"@{var_nm}") if cpn_id else exp, + "name": (self._canvas.get_component_name(cpn_id) + f"@{var_nm}") if cpn_id else exp, "value": self._canvas.get_variable_value(exp), - "_retrival": self._canvas.get_variable_value(f"{cpn_id}@_references") if cpn_id else None, + "_retrieval": self._canvas.get_variable_value(f"{cpn_id}@_references") if cpn_id else None, "_cpn_id": cpn_id } return res @@ -583,6 +551,7 @@ class ComponentBase(ABC): for n, v in kv.items(): def repl(_match, val=v): return str(val) if val is not None else "" + content = re.sub( r"\{%s\}" % re.escape(n), repl, diff --git a/agent/test/dsl_examples/categorize_and_agent_with_tavily.json b/agent/test/dsl_examples/categorize_and_agent_with_tavily.json index 7d9567446..49738f14d 100644 --- a/agent/test/dsl_examples/categorize_and_agent_with_tavily.json +++ b/agent/test/dsl_examples/categorize_and_agent_with_tavily.json @@ -75,7 +75,7 @@ }, "history": [], "path": [], - "retrival": {"chunks": [], "doc_aggs": []}, + "retrieval": {"chunks": [], "doc_aggs": []}, "globals": { "sys.query": "", "sys.user_id": "", diff --git a/agent/test/dsl_examples/iteration.json b/agent/test/dsl_examples/iteration.json index dd4448423..dc976aa8b 100644 --- a/agent/test/dsl_examples/iteration.json +++ b/agent/test/dsl_examples/iteration.json @@ -82,7 +82,7 @@ }, "history": [], "path": [], - "retrival": {"chunks": [], "doc_aggs": []}, + "retrieval": {"chunks": [], "doc_aggs": []}, "globals": { "sys.query": "", "sys.user_id": "", diff --git a/agent/test/dsl_examples/retrieval_and_generate.json b/agent/test/dsl_examples/retrieval_and_generate.json index 9f9f9bac4..b392962e2 100644 --- a/agent/test/dsl_examples/retrieval_and_generate.json +++ b/agent/test/dsl_examples/retrieval_and_generate.json @@ -51,7 +51,7 @@ }, "history": [], "path": [], - "retrival": {"chunks": [], "doc_aggs": []}, + "retrieval": {"chunks": [], "doc_aggs": []}, "globals": { "sys.query": "", "sys.user_id": "", diff --git a/agent/test/dsl_examples/retrieval_categorize_and_generate.json b/agent/test/dsl_examples/retrieval_categorize_and_generate.json index c506b9a6b..ed6866ae5 100644 --- a/agent/test/dsl_examples/retrieval_categorize_and_generate.json +++ b/agent/test/dsl_examples/retrieval_categorize_and_generate.json @@ -85,7 +85,7 @@ }, "history": [], "path": [], - "retrival": {"chunks": [], "doc_aggs": []}, + "retrieval": {"chunks": [], "doc_aggs": []}, "globals": { "sys.query": "", "sys.user_id": "", diff --git a/agent/test/dsl_examples/tavily_and_generate.json b/agent/test/dsl_examples/tavily_and_generate.json index f2f79b4b7..caa10d155 100644 --- a/agent/test/dsl_examples/tavily_and_generate.json +++ b/agent/test/dsl_examples/tavily_and_generate.json @@ -45,7 +45,7 @@ }, "history": [], "path": [], - "retrival": {"chunks": [], "doc_aggs": []}, + "retrieval": {"chunks": [], "doc_aggs": []}, "globals": { "sys.query": "", "sys.user_id": "", diff --git a/rag/app/book.py b/rag/app/book.py index fe079783c..b392d4139 100644 --- a/rag/app/book.py +++ b/rag/app/book.py @@ -166,9 +166,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, sections = [s.split("@") for s, _ in sections] sections = [(pr[0], "@" + pr[1]) if len(pr) == 2 else (pr[0], '') for pr in sections ] chunks = naive_merge( - sections, kwargs.get( - "chunk_token_num", 256), kwargs.get( - "delimer", "\n。;!?")) + sections, + parser_config.get("chunk_token_num", 256), + parser_config.get("delimiter", "\n。;!?") + ) # is it English # is_english(random_choices([t for t, _ in sections], k=218))