diff --git a/rag/app/table.py b/rag/app/table.py index 450cd6280..90b4c2849 100644 --- a/rag/app/table.py +++ b/rag/app/table.py @@ -92,11 +92,15 @@ def column_data_type(arr): arr = list(arr) counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0} trans = {t: f for f, t in [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]} + float_flag = False for a in arr: if a is None: continue - if re.match(r"[+-]?[0-9]{,19}(\.0+)?$", str(a).replace("%%", "")): + if re.match(r"[+-]?[0-9]+$", str(a).replace("%%", "")): counts["int"] += 1 + if int(str(a)) > 2**63 - 1: + float_flag = True + break elif re.match(r"[+-]?[0-9.]{,19}$", str(a).replace("%%", "")): counts["float"] += 1 elif re.match(r"(true|yes|是|\*|✓|✔|☑|✅|√|false|no|否|⍻|×)$", str(a), flags=re.IGNORECASE): @@ -105,8 +109,11 @@ def column_data_type(arr): counts["datetime"] += 1 else: counts["text"] += 1 - counts = sorted(counts.items(), key=lambda x: x[1] * -1) - ty = counts[0][0] + if float_flag: + ty = "float" + else: + counts = sorted(counts.items(), key=lambda x: x[1] * -1) + ty = counts[0][0] for i in range(len(arr)): if arr[i] is None: continue