From 3245107dc7c8dab4f25795bff51dfe62ef6e85c0 Mon Sep 17 00:00:00 2001
From: KevinHuSh
Date: Mon, 25 Dec 2023 19:05:59 +0800
Subject: [PATCH] use minio to store uploaded files; build dialog server; (#16)

* format code

* use minio to store uploaded files; build dialog server;
---
 python/llm/__init__.py        |   1 +
 python/llm/chat_model.py      |  34 ++++++
 python/llm/embedding_model.py |   5 +-
 python/nlp/huchunk.py         |   6 +-
 python/nlp/search.py          | 221 ++++++++++++++++++++++++++++++++++
 python/parser/docx_parser.py  |   3 +-
 python/parser/excel_parser.py |   4 +-
 python/parser/pdf_parser.py   |   3 +-
 python/svr/dialog_svr.py      | 164 +++++++++++++++++++++++++
 python/svr/parse_user_docs.py |  25 ++--
 src/api/doc_info.rs           |  81 ++++++++-----
 src/api/tag.rs                |  58 ---------
 src/main.rs                   |  49 ++++----
 13 files changed, 520 insertions(+), 134 deletions(-)
 create mode 100644 python/llm/chat_model.py
 create mode 100644 python/nlp/search.py
 create mode 100755 python/svr/dialog_svr.py
 delete mode 100644 src/api/tag.rs

diff --git a/python/llm/__init__.py b/python/llm/__init__.py
index 07bb78a74..1352d5b96 100644
--- a/python/llm/__init__.py
+++ b/python/llm/__init__.py
@@ -1 +1,2 @@
 from .embedding_model import HuEmbedding
+from .chat_model import GptTurbo
diff --git a/python/llm/chat_model.py b/python/llm/chat_model.py
new file mode 100644
index 000000000..49c2518e0
--- /dev/null
+++ b/python/llm/chat_model.py
@@ -0,0 +1,34 @@
+from abc import ABC
+import openai
+import os
+
+class Base(ABC):
+    def chat(self, system, history, gen_conf):
+        raise NotImplementedError("Please implement chat method!")
+
+
+class GptTurbo(Base):
+    def __init__(self):
+        openai.api_key = os.environ["OPENAPI_KEY"]
+
+    def chat(self, system, history, gen_conf):
+        history.insert(0, {"role": "system", "content": system})
+        res = openai.ChatCompletion.create(model="gpt-3.5-turbo",
+                                           messages=history,
+                                           **gen_conf)
+        return res.choices[0].message.content.strip()
+
+
+class QWen(Base):
+    def chat(self, system, history, gen_conf):
+        from http import HTTPStatus
+        from dashscope import Generation
+        history.insert(0, {"role": "system", "content": system})
+        response = Generation.call(
+            Generation.Models.qwen_turbo,
+            messages=history,
+            result_format='message'
+        )
+        if response.status_code == HTTPStatus.OK:
+            return response.output.choices[0]['message']['content']
+        return response.message
diff --git a/python/llm/embedding_model.py b/python/llm/embedding_model.py
index cbb8cf553..98a483275 100644
--- a/python/llm/embedding_model.py
+++ b/python/llm/embedding_model.py
@@ -1,6 +1,7 @@
 from abc import ABC
 from FlagEmbedding import FlagModel
 import torch
+import numpy as np
 
 class Base(ABC):
     def encode(self, texts: list, batch_size=32):
@@ -27,5 +28,5 @@ class HuEmbedding(Base):
     def encode(self, texts: list, batch_size=32):
         res = []
         for i in range(0, len(texts), batch_size):
-            res.extend(self.encode(texts[i:i+batch_size]))
-        return res
+            res.extend(self.model.encode(texts[i:i+batch_size]).tolist())
+        return np.array(res)
diff --git a/python/nlp/huchunk.py b/python/nlp/huchunk.py
index 4e5ca8e87..da72ce364 100644
--- a/python/nlp/huchunk.py
+++ b/python/nlp/huchunk.py
@@ -372,7 +372,7 @@ class PptChunker(HuChunker):
 
     def __call__(self, fnm):
         from pptx import Presentation
-        ppt = Presentation(fnm)
+        ppt = Presentation(fnm) if isinstance(fnm, str) else Presentation(BytesIO(fnm))
         flds = self.Fields()
         flds.text_chunks = []
         for slide in ppt.slides:
@@ -396,7 +396,9 @@ class TextChunker(HuChunker):
 
     @staticmethod
     def is_binary_file(file_path):
        mime = magic.Magic(mime=True)
- file_type = mime.from_file(file_path) + if isinstance(file_path, str): + file_type = mime.from_file(file_path) + else:file_type = mime.from_buffer(file_path) if 'text' in file_type: return False else: diff --git a/python/nlp/search.py b/python/nlp/search.py new file mode 100644 index 000000000..e751b66dd --- /dev/null +++ b/python/nlp/search.py @@ -0,0 +1,221 @@ +import re +from elasticsearch_dsl import Q,Search,A +from typing import List, Optional, Tuple,Dict, Union +from dataclasses import dataclass +from util import setup_logging, rmSpace +from nlp import huqie, query +from datetime import datetime +from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity +import numpy as np +from copy import deepcopy + +class Dealer: + def __init__(self, es, emb_mdl): + self.qryr = query.EsQueryer(es) + self.qryr.flds = ["title_tks^10", "title_sm_tks^5", "content_ltks^2", "content_sm_ltks"] + self.es = es + self.emb_mdl = emb_mdl + + @dataclass + class SearchResult: + total:int + ids: List[str] + query_vector: List[float] = None + field: Optional[Dict] = None + highlight: Optional[Dict] = None + aggregation: Union[List, Dict, None] = None + keywords: Optional[List[str]] = None + group_docs: List[List] = None + + def _vector(self, txt, sim=0.8, topk=10): + return { + "field": "q_vec", + "k": topk, + "similarity": sim, + "num_candidates": 1000, + "query_vector": self.emb_mdl.encode_queries(txt) + } + + def search(self, req, idxnm, tks_num=3): + keywords = [] + qst = req.get("question", "") + + bqry,keywords = self.qryr.question(qst) + if req.get("kb_ids"): bqry.filter.append(Q("terms", kb_id=req["kb_ids"])) + bqry.filter.append(Q("exists", field="q_tks")) + bqry.boost = 0.05 + print(bqry) + + s = Search() + pg = int(req.get("page", 1))-1 + ps = int(req.get("size", 1000)) + src = req.get("field", ["docnm_kwd", "content_ltks", "kb_id", + "image_id", "doc_id", "q_vec"]) + + s = s.query(bqry)[pg*ps:(pg+1)*ps] + s = s.highlight("content_ltks") + s = s.highlight("title_ltks") + if not qst: s = s.sort({"create_time":{"order":"desc", "unmapped_type":"date"}}) + + s = s.highlight_options( + fragment_size = 120, + number_of_fragments=5, + boundary_scanner_locale="zh-CN", + boundary_scanner="SENTENCE", + boundary_chars=",./;:\\!(),。?:!……()——、" + ) + s = s.to_dict() + q_vec = [] + if req.get("vector"): + s["knn"] = self._vector(qst, req.get("similarity", 0.4), ps) + s["knn"]["filter"] = bqry.to_dict() + del s["highlight"] + q_vec = s["knn"]["query_vector"] + res = self.es.search(s, idxnm=idxnm, timeout="600s",src=src) + print("TOTAL: ", self.es.getTotal(res)) + if self.es.getTotal(res) == 0 and "knn" in s: + bqry,_ = self.qryr.question(qst, min_match="10%") + if req.get("kb_ids"): bqry.filter.append(Q("terms", kb_id=req["kb_ids"])) + s["query"] = bqry.to_dict() + s["knn"]["filter"] = bqry.to_dict() + s["knn"]["similarity"] = 0.7 + res = self.es.search(s, idxnm=idxnm, timeout="600s",src=src) + + kwds = set([]) + for k in keywords: + kwds.add(k) + for kk in huqie.qieqie(k).split(" "): + if len(kk) < 2:continue + if kk in kwds:continue + kwds.add(kk) + + aggs = self.getAggregation(res, "docnm_kwd") + + return self.SearchResult( + total = self.es.getTotal(res), + ids = self.es.getDocIds(res), + query_vector = q_vec, + aggregation = aggs, + highlight = self.getHighlight(res), + field = self.getFields(res, ["docnm_kwd", "content_ltks", + "kb_id","image_id", "doc_id", "q_vec"]), + keywords = list(kwds) + ) + + def getAggregation(self, res, g): + if not "aggregations" in res or "aggs_"+g not in 
res["aggregations"]:return + bkts = res["aggregations"]["aggs_"+g]["buckets"] + return [(b["key"], b["doc_count"]) for b in bkts] + + def getHighlight(self, res): + def rmspace(line): + eng = set(list("qwertyuioplkjhgfdsazxcvbnm")) + r = [] + for t in line.split(" "): + if not t:continue + if len(r)>0 and len(t)>0 and r[-1][-1] in eng and t[0] in eng:r.append(" ") + r.append(t) + r = "".join(r) + return r + + ans = {} + for d in res["hits"]["hits"]: + hlts = d.get("highlight") + if not hlts:continue + ans[d["_id"]] = "".join([a for a in list(hlts.items())[0][1]]) + return ans + + def getFields(self, sres, flds): + res = {} + if not flds:return {} + for d in self.es.getSource(sres): + m = {n:d.get(n) for n in flds if d.get(n) is not None} + for n,v in m.items(): + if type(v) == type([]): + m[n] = "\t".join([str(vv) for vv in v]) + continue + if type(v) != type(""):m[n] = str(m[n]) + m[n] = rmSpace(m[n]) + + if m:res[d["id"]] = m + return res + + + @staticmethod + def trans2floats(txt): + return [float(t) for t in txt.split("\t")] + + + def insert_citations(self, ans, top_idx, sres, vfield = "q_vec", cfield="content_ltks"): + + ins_embd = [Dealer.trans2floats(sres.field[sres.ids[i]][vfield]) for i in top_idx] + ins_tw =[sres.field[sres.ids[i]][cfield].split(" ") for i in top_idx] + s = 0 + e = 0 + res = "" + def citeit(): + nonlocal s, e, ans, res + if not ins_embd:return + embd = self.emb_mdl.encode(ans[s: e]) + sim = self.qryr.hybrid_similarity(embd, + ins_embd, + huqie.qie(ans[s:e]).split(" "), + ins_tw) + print(ans[s: e], sim) + mx = np.max(sim)*0.99 + if mx < 0.55:return + cita = list(set([top_idx[i] for i in range(len(ins_embd)) if sim[i] >mx]))[:4] + for i in cita: res += f"@?{i}?@" + + return cita + + punct = set(";。?!!") + if not self.qryr.isChinese(ans): + punct.add("?") + punct.add(".") + while e < len(ans): + if e - s < 12 or ans[e] not in punct: + e += 1 + continue + if ans[e] == "." 
and e+1=0 and ans[e-2] == "\n": + e += 1 + continue + res += ans[s: e] + citeit() + res += ans[e] + e += 1 + s = e + + if s< len(ans): + res += ans[s:] + citeit() + + return res + + + def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, vfield="q_vec", cfield="content_ltks"): + ins_embd = [Dealer.trans2floats(sres.field[i]["q_vec"]) for i in sres.ids] + if not ins_embd: return [] + ins_tw =[sres.field[i][cfield].split(" ") for i in sres.ids] + #return CosineSimilarity([sres.query_vector], ins_embd)[0] + sim = self.qryr.hybrid_similarity(sres.query_vector, + ins_embd, + huqie.qie(query).split(" "), + ins_tw, tkweight, vtweight) + return sim + + + +if __name__ == "__main__": + from util import es_conn + SE = Dealer(es_conn.HuEs("infiniflow")) + qs = [ + "胡凯", + "" + ] + for q in qs: + print(">>>>>>>>>>>>>>>>>>>>", q) + print(SE.search({"question": q, "kb_ids": "64f072a75f3b97c865718c4a"}, "infiniflow_*")) diff --git a/python/parser/docx_parser.py b/python/parser/docx_parser.py index 1b5a3efcd..5968b0eae 100644 --- a/python/parser/docx_parser.py +++ b/python/parser/docx_parser.py @@ -3,6 +3,7 @@ import re import pandas as pd from collections import Counter from nlp import huqie +from io import BytesIO class HuDocxParser: @@ -97,7 +98,7 @@ class HuDocxParser: return ["\n".join(lines)] def __call__(self, fnm): - self.doc = Document(fnm) + self.doc = Document(fnm) if isinstance(fnm, str) else Document(BytesIO(fnm)) secs = [(p.text, p.style.name) for p in self.doc.paragraphs] tbls = [self.__extract_table_content(tb) for tb in self.doc.tables] return secs, tbls diff --git a/python/parser/excel_parser.py b/python/parser/excel_parser.py index d03c1d78c..e0f931f99 100644 --- a/python/parser/excel_parser.py +++ b/python/parser/excel_parser.py @@ -1,10 +1,12 @@ from openpyxl import load_workbook import sys +from io import BytesIO class HuExcelParser: def __call__(self, fnm): - wb = load_workbook(fnm) + if isinstance(fnm, str):wb = load_workbook(fnm) + else: wb = load_workbook(BytesIO(fnm)) res = [] for sheetname in wb.sheetnames: ws = wb[sheetname] diff --git a/python/parser/pdf_parser.py b/python/parser/pdf_parser.py index 7fd341518..7f009b929 100644 --- a/python/parser/pdf_parser.py +++ b/python/parser/pdf_parser.py @@ -1,4 +1,5 @@ import xgboost as xgb +from io import BytesIO import torch import re import pdfplumber @@ -1525,7 +1526,7 @@ class HuParser: return "\n\n".join(res) def __call__(self, fnm, need_image=True, zoomin=3, return_html=False): - self.pdf = pdfplumber.open(fnm) + self.pdf = pdfplumber.open(fnm) if isinstance(fnm, str) else pdfplumber.open(BytesIO(fnm)) self.lefted_chars = [] self.mean_height = [] self.mean_width = [] diff --git a/python/svr/dialog_svr.py b/python/svr/dialog_svr.py new file mode 100755 index 000000000..5d683d66d --- /dev/null +++ b/python/svr/dialog_svr.py @@ -0,0 +1,164 @@ +#-*- coding:utf-8 -*- +import sys, os, re,inspect,json,traceback,logging,argparse, copy +sys.path.append(os.path.realpath(os.path.dirname(inspect.getfile(inspect.currentframe())))+"/../") +from tornado.web import RequestHandler,Application +from tornado.ioloop import IOLoop +from tornado.httpserver import HTTPServer +from tornado.options import define,options +from util import es_conn, setup_logging +from svr import sec_search as search +from svr.rpc_proxy import RPCProxy +from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity +from nlp import huqie +from nlp import query as Query +from llm import HuEmbedding, GptTurbo +import numpy as np +from io import BytesIO +from 
+from timeit import default_timer as timer
+from collections import OrderedDict
+
+SE = None
+CFIELD = "content_ltks"
+EMBEDDING = HuEmbedding()
+LLM = GptTurbo()
+
+
+def get_QA_pairs(hists):
+    pa = []
+    for h in hists:
+        for k in ["user", "assistant"]:
+            if h.get(k):
+                pa.append({
+                    "content": h[k],
+                    "role": k,
+                })
+
+    for p in pa[:-1]: assert len(p) == 2, p
+    return pa
+
+
+def get_instruction(sres, top_i, max_len=8096, fld="content_ltks"):
+    max_len //= len(top_i)
+    # add instruction to prompt
+    instructions = [re.sub(r"[\r\n]+", " ", sres.field[sres.ids[i]][fld]) for i in top_i]
+    if len(instructions) > 2:
+        # Said that LLM is sensitive to the first and the last one, so
+        # rearrange the order of references
+        instructions.append(copy.deepcopy(instructions[1]))
+        instructions.pop(1)
+
+    def token_num(txt):
+        c = 0
+        for tk in re.split(r"[,。/?‘’”“:;:;!!]", txt):
+            if re.match(r"[a-zA-Z-]+$", tk):
+                c += 1
+                continue
+            c += len(tk)
+        return c
+
+    _inst = ""
+    for ins in instructions:
+        if token_num(_inst) > 4096:
+            _inst += "\n知识库:" + instructions[-1][:max_len]
+            break
+        _inst += "\n知识库:" + ins[:max_len]
+    return _inst
+
+
+def prompt_and_answer(history, inst):
+    hist = get_QA_pairs(history)
+    chks = []
+    for s in re.split(r"[::;;。\n\r]+", inst):
+        if s: chks.append(s)
+    chks = len(set(chks)) / (0.1 + len(chks))
+    print("Duplication portion:", chks)
+
+    system = """
+你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答%s。当所有知识库内容都与问题无关时,你的回答必须包括"知识库中未找到您要的答案!这是我所知道的,仅作参考。"这句话。回答需要考虑聊天历史。
+以下是知识库:
+%s
+以上是知识库。
+""" % ((",最好总结成表格" if chks < 0.6 and chks > 0 else ""), inst)
+
+    print("【PROMPT】:", system)
+    start = timer()
+    response = LLM.chat(system, hist, {"temperature": 0.2, "max_tokens": 512})
+    print("GENERATE: ", timer() - start)
+    print("===>>", response)
+    return response
+
+
+class Handler(RequestHandler):
+    def post(self):
+        global SE
+        param = json.loads(self.request.body.decode('utf-8'))
+        try:
+            question = param.get("history", [{"user": "Hi!"}])[-1]["user"]
+            res = SE.search({
+                "question": question,
+                "kb_ids": param.get("kb_ids", []),
+                "size": param.get("topn", 15)
+            })
+
+            sim = SE.rerank(res, question)
+            rk_idx = np.argsort(sim * -1)
+            topidx = [i for i in rk_idx if sim[i] >= param.get("similarity", 0.5)][:param.get("topn", 12)]
+            inst = get_instruction(res, topidx)
+
+            ans = prompt_and_answer(param["history"], inst)
+            ans = SE.insert_citations(ans, topidx, res)
+
+            refer = OrderedDict()
+            docnms = {}
+            for i in rk_idx:
+                did = res.field[res.ids[i]]["doc_id"]
+                if did not in docnms: docnms[did] = res.field[res.ids[i]]["docnm_kwd"]
+                if did not in refer: refer[did] = []
+                refer[did].append({
+                    "chunk_id": res.ids[i],
+                    "content": res.field[res.ids[i]]["content_ltks"],
+                    "image": ""
+                })
+
+            print("::::::::::::::", ans)
+            self.write(json.dumps({
+                "code": 0,
+                "msg": "success",
+                "data": {
+                    "uid": param["uid"],
+                    "dialog_id": param["dialog_id"],
+                    "assistant": ans,
+                    "refer": [{
+                        "did": did,
+                        "doc_name": docnms[did],
+                        "chunks": chunks
+                    } for did, chunks in refer.items()]
+                }
+            }))
+            logging.info("SUCCESS[%d]" % (res.total) + json.dumps(param, ensure_ascii=False))
+
+        except Exception as e:
+            logging.error("Request 500: " + str(e))
+            self.write(json.dumps({
+                "code": 500,
+                "msg": str(e),
+                "data": {}
+            }))
+            print(traceback.format_exc())
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--port", default=4455, type=int, help="Port used for service")
+    ARGS = parser.parse_args()
+
+    SE =
search.ResearchReportSearch(es_conn.HuEs("infiniflow"), EMBEDDING) + + app = Application([(r'/v1/chat/completions', Handler)],debug=False) + http_server = HTTPServer(app) + http_server.bind(ARGS.port) + http_server.start(3) + + IOLoop.current().start() + diff --git a/python/svr/parse_user_docs.py b/python/svr/parse_user_docs.py index 742fe5124..573cd1a3d 100644 --- a/python/svr/parse_user_docs.py +++ b/python/svr/parse_user_docs.py @@ -34,18 +34,14 @@ DOC = DocxChunker(DocxParser()) EXC = ExcelChunker(ExcelParser()) PPT = PptChunker() -UPLOAD_LOCATION = os.environ.get("UPLOAD_LOCATION", "./") -logging.warning(f"The files are stored in {UPLOAD_LOCATION}, please check it!") - - -def chuck_doc(name): +def chuck_doc(name, binary): suff = os.path.split(name)[-1].lower().split(".")[-1] - if suff.find("pdf") >= 0: return PDF(name) - if suff.find("doc") >= 0: return DOC(name) - if re.match(r"(xlsx|xlsm|xltx|xltm)", suff): return EXC(name) - if suff.find("ppt") >= 0: return PPT(name) + if suff.find("pdf") >= 0: return PDF(binary) + if suff.find("doc") >= 0: return DOC(binary) + if re.match(r"(xlsx|xlsm|xltx|xltm)", suff): return EXC(binary) + if suff.find("ppt") >= 0: return PPT(binary) - return TextChunker()(name) + return TextChunker()(binary) def collect(comm, mod, tm): @@ -115,7 +111,7 @@ def build(row): random.seed(time.time()) set_progress(row["kb2doc_id"], random.randint(0, 20)/100., "Finished preparing! Start to slice file!") try: - obj = chuck_doc(os.path.join(UPLOAD_LOCATION, row["location"])) + obj = chuck_doc(row["doc_name"], MINIO.get("%s-upload"%str(row["uid"]), row["location"])) except Exception as e: if re.search("(No such file|not found)", str(e)): set_progress(row["kb2doc_id"], -1, "Can not find file <%s>"%row["doc_name"]) @@ -133,9 +129,11 @@ def build(row): doc = { "doc_id": row["did"], "kb_id": [str(row["kb_id"])], + "docnm_kwd": os.path.split(row["location"])[-1], "title_tks": huqie.qie(os.path.split(row["location"])[-1]), "updated_at": str(row["updated_at"]).replace("T", " ")[:19] } + doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"]) output_buffer = BytesIO() docs = [] md5 = hashlib.md5() @@ -144,11 +142,14 @@ def build(row): md5.update((txt + str(d["doc_id"])).encode("utf-8")) d["_id"] = md5.hexdigest() d["content_ltks"] = huqie.qie(txt) + d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"]) if not img: docs.append(d) continue img.save(output_buffer, format='JPEG') - d["img_bin"] = str(output_buffer.getvalue()) + MINIO.put("{}-{}".format(row["uid"], row["kb_id"]), d["_id"], + output_buffer.getvalue()) + d["img_id"] = "{}-{}".format(row["uid"], row["kb_id"]) docs.append(d) for arr, img in obj.table_chunks: diff --git a/src/api/doc_info.rs b/src/api/doc_info.rs index 84886d4ab..ba8006b39 100644 --- a/src/api/doc_info.rs +++ b/src/api/doc_info.rs @@ -1,9 +1,9 @@ use std::collections::HashMap; -use std::io::Write; -use actix_multipart_extract::{File, Multipart, MultipartForm}; -use actix_web::{get, HttpResponse, post, web}; -use chrono::{Utc, FixedOffset}; -use minio::s3::args::{BucketExistsArgs, MakeBucketArgs, UploadObjectArgs}; +use std::io::BufReader; +use actix_multipart_extract::{ File, Multipart, MultipartForm }; +use actix_web::{ HttpResponse, post, web }; +use chrono::{ Utc, FixedOffset }; +use minio::s3::args::{ BucketExistsArgs, MakeBucketArgs, PutObjectArgs }; use sea_orm::DbConn; use crate::api::JsonResponse; use crate::AppState; @@ -12,9 +12,6 @@ use crate::errors::AppError; use crate::service::doc_info::{ Mutation, Query }; use serde::Deserialize; -const 
BUCKET_NAME: &'static str = "docgpt-upload"; - - fn now() -> chrono::DateTime { Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap()) } @@ -74,53 +71,71 @@ async fn upload( ) -> Result { let uid = payload.uid; let file_name = payload.file_field.name.as_str(); - async fn add_number_to_filename(file_name: &str, conn:&DbConn, uid:i64, parent_id:i64) -> String { + async fn add_number_to_filename( + file_name: &str, + conn: &DbConn, + uid: i64, + parent_id: i64 + ) -> String { let mut i = 0; let mut new_file_name = file_name.to_string(); let arr: Vec<&str> = file_name.split(".").collect(); - let suffix = String::from(arr[arr.len()-1]); - let preffix = arr[..arr.len()-1].join("."); - let mut docs = Query::find_doc_infos_by_name(conn, uid, &new_file_name, Some(parent_id)).await.unwrap(); - while docs.len()>0 { + let suffix = String::from(arr[arr.len() - 1]); + let preffix = arr[..arr.len() - 1].join("."); + let mut docs = Query::find_doc_infos_by_name( + conn, + uid, + &new_file_name, + Some(parent_id) + ).await.unwrap(); + while docs.len() > 0 { i += 1; new_file_name = format!("{}_{}.{}", preffix, i, suffix); - docs = Query::find_doc_infos_by_name(conn, uid, &new_file_name, Some(parent_id)).await.unwrap(); + docs = Query::find_doc_infos_by_name( + conn, + uid, + &new_file_name, + Some(parent_id) + ).await.unwrap(); } new_file_name } let fnm = add_number_to_filename(file_name, &data.conn, uid, payload.did).await; - let s3_client = &data.s3_client; + let bucket_name = format!("{}-upload", payload.uid); + let s3_client: &minio::s3::client::Client = &data.s3_client; let buckets_exists = s3_client - .bucket_exists(&BucketExistsArgs::new(BUCKET_NAME)?) - .await?; + .bucket_exists(&BucketExistsArgs::new(&bucket_name).unwrap()).await + .unwrap(); if !buckets_exists { - s3_client - .make_bucket(&MakeBucketArgs::new(BUCKET_NAME)?) - .await?; + print!("Create bucket: {}", bucket_name.clone()); + s3_client.make_bucket(&MakeBucketArgs::new(&bucket_name).unwrap()).await.unwrap(); + } else { + print!("Existing bucket: {}", bucket_name.clone()); } - s3_client - .upload_object( - &mut UploadObjectArgs::new( - BUCKET_NAME, - fnm.as_str(), - format!("/{}/{}-{}", payload.uid, payload.did, fnm).as_str() - )? - ) - .await?; + let location = format!("/{}/{}", payload.did, fnm); + print!("===>{}", location.clone()); + s3_client.put_object( + &mut PutObjectArgs::new( + &bucket_name, + &location, + &mut BufReader::new(payload.file_field.bytes.as_slice()), + Some(payload.file_field.bytes.len()), + None + )? 
+ ).await?; - let location = format!("/{}/{}", BUCKET_NAME, fnm); let doc = Mutation::create_doc_info(&data.conn, Model { - did:Default::default(), - uid: uid, + did: Default::default(), + uid: uid, doc_name: fnm, size: payload.file_field.bytes.len() as i64, location, r#type: "doc".to_string(), created_at: now(), updated_at: now(), - is_deleted:Default::default(), + is_deleted: Default::default(), }).await?; let _ = Mutation::place_doc(&data.conn, payload.did, doc.did.unwrap()).await?; diff --git a/src/api/tag.rs b/src/api/tag.rs deleted file mode 100644 index b902f3d3a..000000000 --- a/src/api/tag.rs +++ /dev/null @@ -1,58 +0,0 @@ -use std::collections::HashMap; -use actix_web::{get, HttpResponse, post, web}; -use actix_web::http::Error; -use crate::api::JsonResponse; -use crate::AppState; -use crate::entity::tag_info; -use crate::service::tag_info::{Mutation, Query}; - -#[post("/v1.0/create_tag")] -async fn create(model: web::Json, data: web::Data) -> Result { - let model = Mutation::create_tag(&data.conn, model.into_inner()).await.unwrap(); - - let mut result = HashMap::new(); - result.insert("tid", model.tid.unwrap()); - - let json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: result, - }; - - Ok(HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response).unwrap())) -} - -#[post("/v1.0/delete_tag")] -async fn delete(model: web::Json, data: web::Data) -> Result { - let _ = Mutation::delete_tag(&data.conn, model.tid).await.unwrap(); - - let json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: (), - }; - - Ok(HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response).unwrap())) -} - -#[get("/v1.0/tags")] -async fn list(data: web::Data) -> Result { - let tags = Query::find_tag_infos(&data.conn).await.unwrap(); - - let mut result = HashMap::new(); - result.insert("tags", tags); - - let json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: result, - }; - - Ok(HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response).unwrap())) -} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index f99c4c397..e301677ed 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,9 +5,9 @@ mod errors; use std::env; use actix_files::Files; -use actix_identity::{CookieIdentityPolicy, IdentityService, RequestIdentity}; +use actix_identity::{ CookieIdentityPolicy, IdentityService, RequestIdentity }; use actix_session::CookieSession; -use actix_web::{web, App, HttpServer, middleware, Error}; +use actix_web::{ web, App, HttpServer, middleware, Error }; use actix_web::cookie::time::Duration; use actix_web::dev::ServiceRequest; use actix_web::error::ErrorUnauthorized; @@ -16,9 +16,9 @@ use listenfd::ListenFd; use minio::s3::client::Client; use minio::s3::creds::StaticProvider; use minio::s3::http::BaseUrl; -use sea_orm::{Database, DatabaseConnection}; -use migration::{Migrator, MigratorTrait}; -use crate::errors::{AppError, UserError}; +use sea_orm::{ Database, DatabaseConnection }; +use migration::{ Migrator, MigratorTrait }; +use crate::errors::{ AppError, UserError }; #[derive(Debug, Clone)] struct AppState { @@ -28,10 +28,10 @@ struct AppState { pub(crate) async fn validator( req: ServiceRequest, - credentials: BearerAuth, + credentials: BearerAuth ) -> Result { if let Some(token) = req.get_identity() { - println!("{}, {}",credentials.token(), token); + println!("{}, {}", credentials.token(), token); 
(credentials.token() == token) .then(|| req) .ok_or(ErrorUnauthorized(UserError::InvalidToken)) @@ -52,26 +52,25 @@ async fn main() -> Result<(), AppError> { let port = env::var("PORT").expect("PORT is not set in .env file"); let server_url = format!("{host}:{port}"); - let s3_base_url = env::var("S3_BASE_URL").expect("S3_BASE_URL is not set in .env file"); - let s3_access_key = env::var("S3_ACCESS_KEY").expect("S3_ACCESS_KEY is not set in .env file");; - let s3_secret_key = env::var("S3_SECRET_KEY").expect("S3_SECRET_KEY is not set in .env file");; + let mut s3_base_url = env::var("MINIO_HOST").expect("MINIO_HOST is not set in .env file"); + let s3_access_key = env::var("MINIO_USR").expect("MINIO_USR is not set in .env file"); + let s3_secret_key = env::var("MINIO_PWD").expect("MINIO_PWD is not set in .env file"); + if s3_base_url.find("http") != Some(0) { + s3_base_url = format!("http://{}", s3_base_url); + } // establish connection to database and apply migrations // -> create post table if not exists let conn = Database::connect(&db_url).await.unwrap(); Migrator::up(&conn, None).await.unwrap(); - let static_provider = StaticProvider::new( - s3_access_key.as_str(), - s3_secret_key.as_str(), - None, - ); + let static_provider = StaticProvider::new(s3_access_key.as_str(), s3_secret_key.as_str(), None); let s3_client = Client::new( s3_base_url.parse::()?, Some(Box::new(static_provider)), None, - None, + Some(true) )?; let state = AppState { conn, s3_client }; @@ -82,18 +81,20 @@ async fn main() -> Result<(), AppError> { App::new() .service(Files::new("/static", "./static")) .app_data(web::Data::new(state.clone())) - .wrap(IdentityService::new( - CookieIdentityPolicy::new(&[0; 32]) - .name("auth-cookie") - .login_deadline(Duration::seconds(120)) - .secure(false), - )) + .wrap( + IdentityService::new( + CookieIdentityPolicy::new(&[0; 32]) + .name("auth-cookie") + .login_deadline(Duration::seconds(120)) + .secure(false) + ) + ) .wrap( CookieSession::signed(&[0; 32]) .name("session-cookie") .secure(false) // WARNING(alex): This uses the `time` crate, not `std::time`! - .expires_in_time(Duration::seconds(60)), + .expires_in_time(Duration::seconds(60)) ) .wrap(middleware::Logger::default()) .configure(init) @@ -137,4 +138,4 @@ fn init(cfg: &mut web::ServiceConfig) { cfg.service(api::user_info::login); cfg.service(api::user_info::register); cfg.service(api::user_info::setting); -} \ No newline at end of file +}
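
Note for reviewers (not part of the patch): the snippet below is a minimal smoke test for the new dialog server added in python/svr/dialog_svr.py, assuming the Tornado service is running locally on its default port (4455) and that parse_user_docs.py has already indexed at least one knowledge base. The uid, dialog_id, and kb id values are placeholders; the request and response fields mirror what Handler.post reads and writes.

# Hypothetical smoke test for the /v1/chat/completions endpoint introduced in this patch.
import json
import requests  # third-party HTTP client, used only for this sketch

DIALOG_URL = "http://127.0.0.1:4455/v1/chat/completions"

payload = {
    "uid": 1,                          # echoed back in data["uid"]
    "dialog_id": "test-dialog",        # echoed back in data["dialog_id"]
    "kb_ids": ["<your-kb-id>"],        # placeholder: knowledge bases to search
    "topn": 12,                        # max number of chunks fed to the LLM
    "similarity": 0.5,                 # rerank cut-off used by the handler
    "history": [{"user": "What does the uploaded document say about pricing?"}],
}

resp = requests.post(DIALOG_URL, data=json.dumps(payload))
body = resp.json()
if body.get("code") == 0:
    data = body["data"]
    # the answer carries @?chunk_id?@ markers produced by Dealer.insert_citations
    print("assistant:", data["assistant"])
    for ref in data["refer"]:
        print("cited doc:", ref["doc_name"], "chunks:", len(ref["chunks"]))
else:
    print("request failed:", body.get("msg"))

The "refer" list groups the reranked chunks by document id, so a client can resolve the citation markers in the answer back to the MinIO-stored source files uploaded through src/api/doc_info.rs.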