fix user login issue (#85)

This commit is contained in:
KevinHuSh
2024-02-29 14:03:07 +08:00
committed by GitHub
parent e86c820461
commit 0429107e80
12 changed files with 99 additions and 79 deletions

View File

@ -33,49 +33,14 @@ from api.utils.api_utils import get_json_result, cors_reponse
@manager.route('/login', methods=['POST', 'GET'])
def login():
userinfo = None
login_channel = "password"
if session.get("access_token"):
login_channel = session["access_token_from"]
if session["access_token_from"] == "github":
userinfo = user_info_from_github(session["access_token"])
elif not request.json:
if not request.json:
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
retmsg='Unautherized!')
email = request.json.get('email') if not userinfo else userinfo["email"]
email = request.json.get('email', "")
users = UserService.query(email=email)
if not users:
if request.json is not None:
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
avatar = ""
try:
avatar = download_img(userinfo["avatar_url"])
except Exception as e:
stat_logger.exception(e)
user_id = get_uuid()
try:
users = user_register(user_id, {
"access_token": session["access_token"],
"email": userinfo["email"],
"avatar": avatar,
"nickname": userinfo["login"],
"login_channel": login_channel,
"last_login_time": get_format_time(),
"is_superuser": False,
})
if not users: raise Exception('Register user failure.')
if len(users) > 1: raise Exception('Same E-mail exist!')
user = users[0]
login_user(user)
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
except Exception as e:
rollback_user_registration(user_id)
stat_logger.exception(e)
return server_error_response(e)
elif not request.json:
login_user(users[0])
return cors_reponse(data=users[0].to_json(), auth=users[0].get_id(), retmsg="Welcome back!")
if not users: return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
password = request.json.get('password')
try:
@ -97,28 +62,50 @@ def login():
@manager.route('/github_callback', methods=['GET'])
def github_callback():
try:
import requests
res = requests.post(GITHUB_OAUTH.get("url"), data={
"client_id": GITHUB_OAUTH.get("client_id"),
"client_secret": GITHUB_OAUTH.get("secret_key"),
"code": request.args.get('code')
},headers={"Accept": "application/json"})
res = res.json()
if "error" in res:
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
retmsg=res["error_description"])
import requests
res = requests.post(GITHUB_OAUTH.get("url"), data={
"client_id": GITHUB_OAUTH.get("client_id"),
"client_secret": GITHUB_OAUTH.get("secret_key"),
"code": request.args.get('code')
}, headers={"Accept": "application/json"})
res = res.json()
if "error" in res:
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
retmsg=res["error_description"])
if "user:email" not in res["scope"].split(","):
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')
if "user:email" not in res["scope"].split(","):
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')
session["access_token"] = res["access_token"]
session["access_token_from"] = "github"
return redirect(url_for("user.login"), code=307)
session["access_token"] = res["access_token"]
session["access_token_from"] = "github"
userinfo = user_info_from_github(session["access_token"])
users = UserService.query(email=userinfo["email"])
user_id = get_uuid()
if not users:
try:
try:
avatar = download_img(userinfo["avatar_url"])
except Exception as e:
stat_logger.exception(e)
avatar = ""
users = user_register(user_id, {
"access_token": session["access_token"],
"email": userinfo["email"],
"avatar": avatar,
"nickname": userinfo["login"],
"login_channel": "github",
"last_login_time": get_format_time(),
"is_superuser": False,
})
if not users: raise Exception('Register user failure.')
if len(users) > 1: raise Exception('Same E-mail exist!')
user = users[0]
login_user(user)
except Exception as e:
rollback_user_registration(user_id)
stat_logger.exception(e)
except Exception as e:
stat_logger.exception(e)
return server_error_response(e)
return redirect("/knowledge")
def user_info_from_github(access_token):
@ -208,7 +195,7 @@ def user_register(user_id, user):
for llm in LLMService.query(fid=LLM_FACTORY):
tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})
if not UserService.insert(**user):return
if not UserService.save(**user):return
TenantService.insert(**tenant)
UserTenantService.insert(**usr_tenant)
TenantLLMService.insert_many(tenant_llm)

View File

@ -69,7 +69,6 @@ class TaskStatus(StrEnum):
class ParserType(StrEnum):
GENERAL = "general"
PRESENTATION = "presentation"
LAWS = "laws"
MANUAL = "manual"

View File

@ -475,7 +475,7 @@ class Knowledgebase(DataBaseModel):
similarity_threshold = FloatField(default=0.2)
vector_similarity_weight = FloatField(default=0.3)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.GENERAL.value)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value)
parser_config = JSONField(null=False, default={"pages":[[0,1000000]]})
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")

View File

@ -30,7 +30,7 @@ def init_superuser():
"password": "admin",
"nickname": "admin",
"is_superuser": True,
"email": "kai.hu@infiniflow.org",
"email": "admin@ragflow.io",
"creator": "system",
"status": "1",
}
@ -61,7 +61,7 @@ def init_superuser():
TenantService.insert(**tenant)
UserTenantService.insert(**usr_tenant)
TenantLLMService.insert_many(tenant_llm)
print("【INFO】Super user initialized. \033[93muser name: admin, password: admin\033[0m. Changing the password after logining is strongly recomanded.")
print("【INFO】Super user initialized. \033[93memail: admin@ragflow.io, password: admin\033[0m. Changing the password after logining is strongly recomanded.")
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime
import peewee
from werkzeug.security import generate_password_hash, check_password_hash
@ -20,7 +22,7 @@ from api.db import UserTenantRole
from api.db.db_models import DB, UserTenant
from api.db.db_models import User, Tenant
from api.db.services.common_service import CommonService
from api.utils import get_uuid, get_format_time
from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
from api.db import StatusEnum
@ -53,6 +55,11 @@ class UserService(CommonService):
kwargs["id"] = get_uuid()
if "password" in kwargs:
kwargs["password"] = generate_password_hash(str(kwargs["password"]))
kwargs["create_time"] = current_timestamp()
kwargs["create_date"] = datetime_format(datetime.now())
kwargs["update_time"] = current_timestamp()
kwargs["update_date"] = datetime_format(datetime.now())
obj = cls.model(**kwargs).save(force_insert=True)
return obj
@ -66,10 +73,10 @@ class UserService(CommonService):
@classmethod
@DB.connection_context()
def update_user(cls, user_id, user_dict):
date_time = get_format_time()
with DB.atomic():
if user_dict:
user_dict["update_time"] = date_time
user_dict["update_time"] = current_timestamp()
user_dict["update_date"] = datetime_format(datetime.now())
cls.model.update(user_dict).where(cls.model.id == user_id).execute()

View File

@ -76,7 +76,7 @@ ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
API_KEY = LLM.get("api_key", "infiniflow API Key")
PARSERS = LLM.get("parsers", "general:General,qa:Q&A,resume:Resume,naive:Naive,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
# distribution
DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)