diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 0a4aebe82..f54ebf709 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -619,7 +619,12 @@ def chat(dialog, messages, stream=True, **kwargs): def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None): - sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question." + sys_prompt = """ +You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question. +Ensure that: +1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it. +2. Write only the SQL, no explanations or additional text. +""" user_prompt = """ Table name: {}; Table of database fields are as follows: @@ -640,6 +645,7 @@ Please write the SQL, only SQL, without any other explanations or text. sql = re.sub(r".*select ", "select ", sql.lower()) sql = re.sub(r" +", " ", sql) sql = re.sub(r"([;;]|```).*", "", sql) + sql = re.sub(r"&", "and", sql) if sql[: len("select ")] != "select ": return None, None if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):