Test APIs and fix bugs (#41)

KevinHuSh
2024-01-22 19:51:38 +08:00
committed by GitHub
parent 484e5abc1f
commit 34b2ab3b2f
11 changed files with 46 additions and 27 deletions

View File

@@ -19,31 +19,39 @@ import os
class Base(ABC):
    def __init__(self, key, model_name):
        pass
    def chat(self, system, history, gen_conf):
        raise NotImplementedError("Please implement chat method!")
class GptTurbo(Base):
    def __init__(self):
        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
    def __init__(self, key, model_name="gpt-3.5-turbo"):
        self.client = OpenAI(api_key=key)
        self.model_name = model_name
    def chat(self, system, history, gen_conf):
        history.insert(0, {"role": "system", "content": system})
        res = self.client.chat.completions.create(
            model="gpt-3.5-turbo",
            model=self.model_name,
            messages=history,
            **gen_conf)
        return res.choices[0].message.content.strip()
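
The constructor refactor above swaps the hard-coded environment variable for an explicit key and a configurable model name. A minimal usage sketch of the new signature; the key string and generation parameters are placeholders, not values from the commit:

# Hypothetical caller of the refactored class; "sk-..." is a placeholder key.
mdl = GptTurbo(key="sk-...", model_name="gpt-3.5-turbo")
ans = mdl.chat(
    system="You are a helpful assistant.",
    history=[{"role": "user", "content": "Hello"}],
    gen_conf={"temperature": 0.7, "max_tokens": 256})
print(ans)
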
from dashscope import Generation
class QWenChat(Base):
    def __init__(self, key, model_name=Generation.Models.qwen_turbo):
        import dashscope
        dashscope.api_key = key
        self.model_name = model_name
    def chat(self, system, history, gen_conf):
        from http import HTTPStatus
        from dashscope import Generation
        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
        history.insert(0, {"role": "system", "content": system})
        response = Generation.call(
            Generation.Models.qwen_turbo,
            self.model_name,
            messages=history,
            result_format='message'
        )
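
The hunk cuts off before the response is consumed. For context, a sketch of how a dashscope Generation.call result is typically read when result_format='message'; the field access follows dashscope's documented response shape, and the error branch is an assumption, not code from this commit:

from http import HTTPStatus

# Assumed continuation: dashscope returns an HTTP-style status code and,
# on success, OpenAI-shaped choices under response.output.
if response.status_code == HTTPStatus.OK:
    ans = response.output.choices[0]['message']['content']
else:
    ans = "ERROR: " + response.message
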

View File

@@ -28,6 +28,8 @@ class Base(ABC):
        raise NotImplementedError("Please implement encode method!")
    def image2base64(self, image):
        if isinstance(image, bytes):
            return base64.b64encode(image).decode("utf-8")
        if isinstance(image, BytesIO):
            return base64.b64encode(image.getvalue()).decode("utf-8")
        buffered = BytesIO()
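
With the two new isinstance branches, image2base64 accepts raw bytes and BytesIO objects in addition to PIL images (the truncated buffered = BytesIO() path). A minimal sketch exercising all three inputs, assuming Pillow is available; GptV4 stands in for any Base subclass and the key is a placeholder:

from io import BytesIO
from PIL import Image

img = Image.new("RGB", (8, 8))                       # tiny in-memory test image
buf = BytesIO()
img.save(buf, format="JPEG")

mdl = GptV4(key="sk-...")                            # placeholder key
b64_bytes = mdl.image2base64(buf.getvalue())         # bytes branch
b64_bio = mdl.image2base64(BytesIO(buf.getvalue()))  # BytesIO branch
b64_pil = mdl.image2base64(img)                      # falls through to PIL path
assert b64_bytes == b64_bio                          # same payload, same encoding
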
@@ -59,7 +61,7 @@ class Base(ABC):
class GptV4(Base):
    def __init__(self, key, model_name="gpt-4-vision-preview"):
        self.client = OpenAI(key)
        self.client = OpenAI(api_key=key)
        self.model_name = model_name
    def describe(self, image, max_tokens=300):
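
The one-line fix matters because the v1 OpenAI Python client takes keyword-only constructor arguments, so passing the key positionally raises a TypeError. A minimal reproduction, assuming openai>=1.0:

from openai import OpenAI

# client = OpenAI("sk-...")        # TypeError: arguments are keyword-only
client = OpenAI(api_key="sk-...")  # placeholder key; the corrected form
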

View File

@@ -187,9 +187,10 @@ class Dealer:
            if len(t) < 5: continue
            idx.append(i)
            pieces_.append(t)
        es_logger.info("{} => {}".format(answer, pieces_))
        if not pieces_: return answer
        ans_v = embd_mdl.encode(pieces_)
        ans_v, c = embd_mdl.encode(pieces_)
        assert len(ans_v[0]) == len(chunk_v[0]), "The dimensions of query and chunk do not match: {} vs. {}".format(
            len(ans_v[0]), len(chunk_v[0]))
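
The change to ans_v, c = embd_mdl.encode(pieces_) tracks a new encode contract: the model now returns the vectors together with the token count it consumed. A sketch of that contract with a dummy model; the 768-dimension and whitespace token counting are illustrative assumptions:

import numpy as np

class DummyEmbed:
    """Stand-in embedding model following the assumed (vectors, tokens) contract."""
    def encode(self, texts):
        vecs = np.random.rand(len(texts), 768)            # fake embeddings
        token_count = sum(len(t.split()) for t in texts)  # crude token count
        return vecs, token_count

ans_v, c = DummyEmbed().encode(["first piece", "second piece"])
assert ans_v.shape == (2, 768)
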
@@ -219,7 +220,7 @@ class Dealer:
            Dealer.trans2floats(
                sres.field[i]["q_%d_vec" % len(sres.query_vector)]) for i in sres.ids]
        if not ins_embd:
            return []
            return [], [], []
        ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ") for i in sres.ids]
        sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
                                                        ins_embd,
@@ -235,6 +236,8 @@ class Dealer:
    def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
                  vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True):
        ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
        if not question: return ranks
        req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": top,
               "question": question, "vector": True,
               "similarity": similarity_threshold}
@@ -243,7 +246,7 @@ class Dealer:
        sim, tsim, vsim = self.rerank(
            sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
        idx = np.argsort(sim * -1)
        ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
        dim = len(sres.query_vector)
        start_idx = (page - 1) * page_size
        for i in idx:
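
Taken together, the retrieval changes initialize ranks once, bail out early on an empty question, and drop the duplicate re-initialization that previously wiped the dict after rerank. A hypothetical call showing the new guard; the Dealer constructor argument and the keyword values are placeholders:

# An empty question now short-circuits before any search or rerank work.
dealer = Dealer(es)   # "es" is a placeholder Elasticsearch handle
empty = dealer.retrieval("", embd_mdl=None, tenant_id="t0", kb_ids=["kb0"],
                         page=1, page_size=10)
assert empty == {"total": 0, "chunks": [], "doc_aggs": {}}
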

View File

@@ -78,6 +78,7 @@ def chuck_doc(name, binary, cvmdl=None):
        field = TextChunker.Fields()
        field.text_chunks = [(txt, binary)]
        field.table_chunks = []
        return field
    return TextChunker()(binary)
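
The added return field closes a fall-through bug: the branch built a Fields object and then dropped it by continuing to the generic TextChunker path. The pattern in miniature, as a toy reconstruction rather than the project's actual chuck_doc:

def chunk(is_plain_text, txt, binary):
    if is_plain_text:
        field = {"text_chunks": [(txt, binary)], "table_chunks": []}
        return field   # without this return, "field" was silently discarded
    return {"text_chunks": [], "table_chunks": []}
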
@@ -161,9 +162,9 @@ def build(row, cvmdl):
    doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
    output_buffer = BytesIO()
    docs = []
    md5 = hashlib.md5()
    for txt, img in obj.text_chunks:
        d = copy.deepcopy(doc)
        md5 = hashlib.md5()
        md5.update((txt + str(d["doc_id"])).encode("utf-8"))
        d["_id"] = md5.hexdigest()
        d["content_ltks"] = huqie.qie(txt)
@@ -186,6 +187,7 @@ def build(row, cvmdl):
        for i, txt in enumerate(arr):
            d = copy.deepcopy(doc)
            d["content_ltks"] = huqie.qie(txt)
            md5 = hashlib.md5()
            md5.update((txt + str(d["doc_id"])).encode("utf-8"))
            d["_id"] = md5.hexdigest()
            if not img:
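
Both hunks move hashlib.md5() inside the loop for the same reason: update() accumulates state, so a digest shared across iterations hashes the concatenation of every chunk seen so far, and all chunks after the first get wrong IDs. A self-contained demonstration:

import hashlib

shared = hashlib.md5()
shared.update(b"chunk-1")
shared.update(b"chunk-2")  # digest now covers b"chunk-1chunk-2", not b"chunk-2"
assert shared.hexdigest() != hashlib.md5(b"chunk-2").hexdigest()
assert shared.hexdigest() == hashlib.md5(b"chunk-1chunk-2").hexdigest()
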
@@ -226,9 +228,6 @@ def embedding(docs, mdl):
def main(comm, mod):
    global model
    from rag.llm import HuEmbedding
    model = HuEmbedding()
    tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
    tm = findMaxTm(tm_fnm)
    rows = collect(comm, mod, tm)
@@ -260,13 +259,14 @@ def main(comm, mod):
        set_progress(r["id"], random.randint(70, 95) / 100.,
                     "Finished embedding! Start to build index!")
        init_kb(r)
        chunk_count = len(set([c["_id"] for c in cks]))
        es_r = ELASTICSEARCH.bulk(cks, search.index_name(r["tenant_id"]))
        if es_r:
            set_progress(r["id"], -1, "Index failure!")
            cron_logger.error(str(es_r))
        else:
            set_progress(r["id"], 1., "Done!")
            DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, len(cks), timer()-st_tm)
            DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, chunk_count, timer()-st_tm)
            cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))
        tmf.write(str(r["update_time"]) + "\n")
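
Counting len(cks) overstates the stored total whenever two chunks hash to the same _id (for example, identical text within one document), since Elasticsearch collapses duplicate IDs into a single document. Deduplicating before the increment keeps the bookkeeping consistent:

# Duplicate chunk IDs should not inflate the persisted chunk count.
cks = [{"_id": "a1"}, {"_id": "b2"}, {"_id": "a1"}]
chunk_count = len(set(c["_id"] for c in cks))
assert chunk_count == 2 and len(cks) == 3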