diff --git a/agent/canvas.py b/agent/canvas.py index ffa67c73d..81c71b2e2 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -153,6 +153,16 @@ class Graph: def get_tenant_id(self): return self._tenant_id + def get_variable_value(self, exp: str) -> Any: + exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}") + if exp.find("@") < 0: + return self.globals[exp] + cpn_id, var_nm = exp.split("@") + cpn = self.get_component(cpn_id) + if not cpn: + raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'") + return cpn["obj"].output(var_nm) + class Canvas(Graph): @@ -406,16 +416,6 @@ class Canvas(Graph): return False return True - def get_variable_value(self, exp: str) -> Any: - exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}") - if exp.find("@") < 0: - return self.globals[exp] - cpn_id, var_nm = exp.split("@") - cpn = self.get_component(cpn_id) - if not cpn: - raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'") - return cpn["obj"].output(var_nm) - def get_history(self, window_size): convs = [] if window_size <= 0: diff --git a/agent/component/llm.py b/agent/component/llm.py index 9db894305..a378ad0ba 100644 --- a/agent/component/llm.py +++ b/agent/component/llm.py @@ -102,6 +102,8 @@ class LLM(ComponentBase): def get_input_elements(self) -> dict[str, Any]: res = self.get_input_elements_from_text(self._param.sys_prompt) + if isinstance(self._param.prompts, str): + self._param.prompts = [{"role": "user", "content": self._param.prompts}] for prompt in self._param.prompts: d = self.get_input_elements_from_text(prompt["content"]) res.update(d) @@ -114,6 +116,8 @@ class LLM(ComponentBase): self._param.sys_prompt += txt def _sys_prompt_and_msg(self, msg, args): + if isinstance(self._param.prompts, str): + self._param.prompts = [{"role": "user", "content": self._param.prompts}] for p in self._param.prompts: if msg and msg[-1]["role"] == p["role"]: continue diff --git a/rag/flow/extractor/extractor.py b/rag/flow/extractor/extractor.py index 242da56a8..2fdec438e 100644 --- a/rag/flow/extractor/extractor.py +++ b/rag/flow/extractor/extractor.py @@ -14,9 +14,10 @@ # limitations under the License. import random from agent.component.llm import LLMParam, LLM +from rag.flow.base import ProcessBase, ProcessParamBase -class ExtractorParam(LLMParam): +class ExtractorParam(ProcessParamBase, LLMParam): def __init__(self): super().__init__() self.field_name = "" @@ -26,7 +27,7 @@ class ExtractorParam(LLMParam): self.check_empty(self.field_name, "Result Destination") -class Extractor(LLM): +class Extractor(ProcessBase, LLM): component_name = "Extractor" async def _invoke(self, **kwargs): diff --git a/rag/flow/extractor/schema.py b/rag/flow/extractor/schema.py index 214e500e2..542e31731 100644 --- a/rag/flow/extractor/schema.py +++ b/rag/flow/extractor/schema.py @@ -30,7 +30,7 @@ class ExtractorFromUpstream(BaseModel): json_result: list[dict[str, Any]] | None = Field(default=None, alias="json") markdown_result: str | None = Field(default=None, alias="markdown") text_result: str | None = Field(default=None, alias="text") - html_result: list[str] | None = Field(default=None, alias="html") + html_result: str | None = Field(default=None, alias="html") model_config = ConfigDict(populate_by_name=True, extra="forbid") diff --git a/rag/flow/hierarchical_merger/schema.py b/rag/flow/hierarchical_merger/schema.py index e45610fe5..2c59497ed 100644 --- a/rag/flow/hierarchical_merger/schema.py +++ b/rag/flow/hierarchical_merger/schema.py @@ -29,7 +29,7 @@ class HierarchicalMergerFromUpstream(BaseModel): json_result: list[dict[str, Any]] | None = Field(default=None, alias="json") markdown_result: str | None = Field(default=None, alias="markdown") text_result: str | None = Field(default=None, alias="text") - html_result: list[str] | None = Field(default=None, alias="html") + html_result: str | None = Field(default=None, alias="html") model_config = ConfigDict(populate_by_name=True, extra="forbid") diff --git a/rag/flow/pipeline.py b/rag/flow/pipeline.py index 8a30dce64..e76eb8e9c 100644 --- a/rag/flow/pipeline.py +++ b/rag/flow/pipeline.py @@ -143,6 +143,10 @@ class Pipeline(Graph): async def invoke(): nonlocal last_cpn, cpn_obj await cpn_obj.invoke(**last_cpn.output()) + #if inspect.iscoroutinefunction(cpn_obj.invoke): + # await cpn_obj.invoke(**last_cpn.output()) + #else: + # cpn_obj.invoke(**last_cpn.output()) async with trio.open_nursery() as nursery: nursery.start_soon(invoke) diff --git a/rag/flow/splitter/schema.py b/rag/flow/splitter/schema.py index cf097d792..9144a4d9b 100644 --- a/rag/flow/splitter/schema.py +++ b/rag/flow/splitter/schema.py @@ -30,7 +30,7 @@ class SplitterFromUpstream(BaseModel): json_result: list[dict[str, Any]] | None = Field(default=None, alias="json") markdown_result: str | None = Field(default=None, alias="markdown") text_result: str | None = Field(default=None, alias="text") - html_result: list[str] | None = Field(default=None, alias="html") + html_result: str | None = Field(default=None, alias="html") model_config = ConfigDict(populate_by_name=True, extra="forbid") diff --git a/rag/flow/tokenizer/schema.py b/rag/flow/tokenizer/schema.py index 7ba0c30a6..af47c87c8 100644 --- a/rag/flow/tokenizer/schema.py +++ b/rag/flow/tokenizer/schema.py @@ -31,7 +31,7 @@ class TokenizerFromUpstream(BaseModel): json_result: list[dict[str, Any]] | None = Field(default=None, alias="json") markdown_result: str | None = Field(default=None, alias="markdown") text_result: str | None = Field(default=None, alias="text") - html_result: list[str] | None = Field(default=None, alias="html") + html_result: str | None = Field(default=None, alias="html") model_config = ConfigDict(populate_by_name=True, extra="forbid") diff --git a/rag/flow/tokenizer/tokenizer.py b/rag/flow/tokenizer/tokenizer.py index a97e43ed2..6a73edd8d 100644 --- a/rag/flow/tokenizer/tokenizer.py +++ b/rag/flow/tokenizer/tokenizer.py @@ -119,7 +119,7 @@ class Tokenizer(ProcessBase): if ck.get("questions"): ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"])) if ck.get("keywords"): - ck["important_tks"] = rag_tokenizer.tokenize("\n".join(ck["keywords"])) + ck["important_tks"] = rag_tokenizer.tokenize(",".join(ck["keywords"])) if ck.get("summary"): ck["content_ltks"] = rag_tokenizer.tokenize(ck["summary"]) ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"]) diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index c8b07e758..2f611c2d3 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -222,10 +222,9 @@ async def collect(): return None, None canceled = False - if msg.get("doc_id", "") == GRAPH_RAPTOR_FAKE_DOC_ID: + if msg.get("doc_id", "") in [GRAPH_RAPTOR_FAKE_DOC_ID, CANVAS_DEBUG_DOC_ID]: task = msg if task["task_type"] in ["graphrag", "raptor"] and msg.get("doc_ids", []): - print(f"hack {msg['doc_ids']=}=",flush=True) task = TaskService.get_task(msg["id"], msg["doc_ids"]) task["doc_ids"] = msg["doc_ids"] else: