be more specific for error message (#1409)

### What problem does this PR solve?

#918 

### Type of change

- [x] Refactoring
This commit is contained in:
KevinHuSh
2024-07-08 09:32:44 +08:00
committed by GitHub
parent dcb3fb2073
commit b3ebc66b13
9 changed files with 126 additions and 61 deletions

View File

@ -121,7 +121,6 @@ class Canvas(ABC):
if desc["to"] not in cpn["downstream"]:
cpn["downstream"].append(desc["to"])
self.path = self.dsl["path"]
self.history = self.dsl["history"]
self.messages = self.dsl["messages"]
@ -136,9 +135,21 @@ class Canvas(ABC):
self.dsl["answer"] = self.answer
self.dsl["reference"] = self.reference
self.dsl["embed_id"] = self._embed_id
dsl = deepcopy(self.dsl)
dsl = {
"components": {}
}
for k in self.dsl.keys():
if k in ["components"]:continue
dsl[k] = deepcopy(self.dsl[k])
for k, cpn in self.components.items():
dsl["components"][k]["obj"] = json.loads(str(cpn["obj"]))
if k not in dsl["components"]:
dsl["components"][k] = {}
for c in cpn.keys():
if c == "obj":
dsl["components"][k][c] = json.loads(str(cpn["obj"]))
continue
dsl["components"][k][c] = deepcopy(cpn[c])
return json.dumps(dsl, ensure_ascii=False)
def reset(self):
@ -161,6 +172,9 @@ class Canvas(ABC):
except Exception as e:
ans = ComponentBase.be_output(str(e))
self.path[-1].append(cpn_id)
if kwargs.get("stream"):
assert isinstance(ans, partial)
return ans
self.history.append(("assistant", ans.to_dict("records")))
return ans
@ -190,6 +204,8 @@ class Canvas(ABC):
cpn = self.get_component(cpn_id)
if not cpn["downstream"]: break
if self._find_loop(): raise OverflowError("Too much loops!")
if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]:
switch_out = cpn["obj"].output()[1].iloc[0, 0]
assert switch_out in self.components, \
@ -249,3 +265,27 @@ class Canvas(ABC):
def get_embedding_model(self):
return self._embed_id
def _find_loop(self, max_loops=2):
path = self.path[-1][::-1]
if len(path) < 2: return False
for i in range(len(path)):
if path[i].lower().find("answer") >= 0:
path = path[:i]
break
if len(path) < 2: return False
for l in range(1, len(path) // 2):
pat = ",".join(path[0:l])
path_str = ",".join(path)
if len(pat) >= len(path_str): return False
path_str = path_str[len(pat):]
loop = max_loops
while path_str.find(pat) >= 0 and loop >= 0:
loop -= 1
path_str = path_str[len(pat):]
if loop < 0: return True
return False