123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322 |
import asyncio
import logging
import os

import openai
from fastapi import FastAPI, Request
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from openai import OpenAI
from pydantic import BaseModel
from pymysql import OperationalError
from starlette.middleware.cors import CORSMiddleware

from LocalModel import CustomLogin, SaveUser, QueryUser, DeleteUser
from db_decorator import peewee_db_close
from logic import *
from model import CustomUser, UserInfo, ZaneTest, database
# OpenAI credential used by the AI helpers in this module. A non-empty
# OPENAI_API_KEY environment variable takes precedence over the literal.
# NOTE(review): the fallback key below is committed to source control and
# should be treated as compromised -- rotate it and delete the literal once
# every deployment provides OPENAI_API_KEY.
API_KEY = (
    os.environ.get("OPENAI_API_KEY")
    or "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"
)

app = FastAPI()

# Accept cross-origin requests from any origin, with any method and header.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; consider an explicit origin list -- confirm with the frontend.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

logging.basicConfig(
    level=logging.INFO,  # minimum severity that is recorded
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',  # record layout
    datefmt='%Y-%m-%d %H:%M:%S',  # timestamp layout
    filename='app.log',  # file the records are written to
    filemode='a',  # append to the file; 'w' would overwrite it instead
)
def check_db_connect():
    """Probe the shared database connection and reconnect if it was dropped.

    Opens ``database`` (reusing the connection when it is already open) and
    iterates over ``CustomUser.select()``, printing each row id, so that a
    real query is sent to the server. When the query fails with an
    ``OperationalError`` whose text contains "MySQL server has gone away",
    the connection is closed and reopened. Any other ``OperationalError`` is
    logged with its traceback instead of being silently discarded; it is
    still not re-raised, so the probe stays best-effort.

    NOTE(review): ``OperationalError`` here is the pymysql class. If
    ``database`` wraps driver errors in its own exception types, this handler
    may never fire -- confirm against the ``model`` module.
    """
    try:
        database.connect(reuse_if_open=True)
        for item in CustomUser.select():
            print(item.id)
    except OperationalError as e:
        if 'MySQL server has gone away' in str(e):
            # Stale connection: drop it and open a fresh one.
            database.close()
            database.connect()
            logging.info("reconnect database")
        else:
            # Previously swallowed without a trace; keep it visible.
            logging.exception("database connection check failed")
- # threading.Timer(60 * 60, check_db_connect).start()
- # check_db_connect()
class Question(BaseModel):
    """Request body for an AI prompt.

    The only consumer visible in this file is the commented-out ``/ai/``
    endpoint, which passes ``content`` to the AI helpers and uses ``stream``
    to choose between the streaming and the one-shot helper.
    """

    # Required string; not read by any code visible in this file.
    user: str
    # Required string; the prompt text.
    content: str
    # Optional flag, True when the client omits it.
    stream: bool = True
async def ai_stream(content: str):
    """Stream a chat completion for *content* as newline-terminated pieces.

    Sends *content* as a single message to the ``gpt-3.5-turbo`` model with
    streaming enabled and yields every non-empty text delta followed by
    ``"\\n"``. The asynchronous OpenAI client is used so that waiting for the
    network does not block the event loop this generator runs on.

    Args:
        content: Text sent as the only message of the conversation.

    Yields:
        Each text fragment of the reply plus a trailing newline. If an
        exception is raised while the stream is being consumed, a single
        ``"Error: <exception>\\n"`` line is yielded and the stream ends.

    Raises:
        Whatever the OpenAI client raises while the request is being created;
        only failures during iteration are reported in-stream.

    NOTE(review): the prompt is sent with role "system"; confirm that "user"
    was not intended.
    """
    client = openai.AsyncOpenAI(api_key=API_KEY)
    completion = await client.chat.completions.create(
        model="gpt-3.5-turbo",
        stream=True,
        messages=[
            {"role": "system", "content": content}
        ]
    )
    try:
        async for chunk in completion:
            # Some chunks can arrive without any choice; skip them instead of
            # failing on the index lookup and aborting the stream.
            if not chunk.choices:
                continue
            piece = chunk.choices[0].delta.content
            if piece:
                yield piece + "\n"
                # Zero-delay yield point so other tasks get a turn even when
                # several chunks are already buffered.
                await asyncio.sleep(0)
    except Exception as e:
        yield f"Error: {e}\n"
def ai_normal(content: str):
    """Request a complete (non-streamed) chat reply for *content*.

    Args:
        content: Text sent as the single message of the conversation, with
            role "system".

    Returns:
        A dict ``{"msg": <text of the first choice>}``.
    """
    prompt = [{"role": "system", "content": content}]
    reply = OpenAI(api_key=API_KEY).chat.completions.create(
        model="gpt-3.5-turbo",
        messages=prompt,
    )
    first_choice = reply.choices[0]
    return {"msg": first_choice.message.content}
- # @app.post("/ai/")
- # async def do_ai(question: Question):
- # if question.stream:
- # return StreamingResponse(ai_stream(question.content), media_type="text/event-stream")
- # else:
- # return ai_normal(question.content)
- # def get_value(s: str):
- # return s
- # class MyRequest(BaseModel):
- # content: str
- # def test_func(arg1: str):
- # print(arg1)
- # return "nice"
- # @app.post("/func/")
- # async def call_func(mq: MyRequest):
- # client = OpenAI(api_key=API_KEY)
- # messages = []
- # messages.append({"role": "system",
- # "content": "You are a helpful assistant"})
- # messages.append({"role": "system",
- # "content": "If you need to call a function but you do not have enough parameters, ask the user to provide you with the missing parameters."})
- # messages.append({"role": "system",
- # "content": "You must answer in Chinese"})
- # messages.append({"role": "user", "content": mq.content})
- #
- # tools = [{
- # "type": "function",
- # "function": {
- # "name": "get_user_birthday",
- # "description": "get user birthday",
- # "parameters": {
- # "type": "object",
- # "properties": {
- # "birthday": {
- # "type": "string",
- # "description": "user birthday"
- # },
- # "city": {
- # "type": "string",
- # "description": "city of user born"
- # }
- # },
- # "required": ["birthday", "city"],
- # }
- # }
- # }]
- #
- # completion1 = client.chat.completions.create(
- # model="gpt-4",
- # messages=messages,
- # tools=tools
- # )
- # ast1 = completion1.choices[0].message
- # return {"msg": ast1}
- #
- # class YearInfo(BaseModel):
- # year: int
- # month: int
- # day: int
- # hour: int
- # minute: int
- #
- # @app.post("/wnl/add/")
- # async def add_wnl(info: YearInfo):
- # # result = []
- # # for year in range(info.from_year, info.to_year+1):
- # # for month in range(1, 13):
- # # max_day = 30
- # # if month == 2:
- # # if year % 4 == 0:
- # # max_day = 29
- # # else:
- # # max_day = 28
- # # elif month in (1, 3, 5, 7, 8, 10, 12):
- # # max_day = 31
- # # for day in range(1, max_day + 1):
- # # result.append({
- # # "nian": year,
- # # "yue": month,
- # # "ri": day
- # # })
- # # Wannianli.insert_many(result).execute()
- # wnl = Wannianli.select()
- # ct = len(wnl)
- # return {"data": "新增了" + str(ct) + "条数据"}
- #
- #
- # @app.post("/wnl/update/")
- # async def update_wnl(info: YearInfo):
- # data = get_wannianli_data(info.year, info.month, info.day)
- # msg = []
- # if data is not None:
- # msg = [data.nian_gan, data.nian_zhi,
- # data.yue_gan, data.yue_zhi,
- # data.ri_gan, data.ri_zhi]
- # hour_data = get_hour_of_day(data.ri_gan, info.hour)
- # msg.append(hour_data[0])
- # msg.append(hour_data[1])
- # return {"date": str(info.year) + "-" + str(info.month) + "-" + str(info.day) + " " + str(info.hour) + ":" + str(
- # info.minute),
- # "msg": msg}
@app.post("/api/getSiZhuInfo")
async def getSiZhuInfo(request: SiZhuInfoRequest):
    """Build a ``BaZi`` object for the request and return it JSON-encoded.

    When ``request.mode`` is 2, ``calc_date_of_sizhu`` is called first and
    its result, if not ``None``, is stored as a string in
    ``bazi.taiyangshi``. The object is then completed by
    ``fill_sizhu_in_bazi`` and passed through ``jsonable_encoder``.
    """
    # Evaluated before BaZi(request) is built, as in the original flow.
    start_dtm = calc_date_of_sizhu(request) if request.mode == 2 else None

    bazi = BaZi(request)
    data_center = DataCenter(bazi)
    if start_dtm is not None:
        bazi.taiyangshi = str(start_dtm)

    fill_sizhu_in_bazi(bazi, data_center)
    return jsonable_encoder(bazi)
@app.post("/api/customLogin")
async def customLogin(request: CustomLogin):
    """Look up a ``CustomUser`` matching the submitted user and password.

    Returns:
        ``{"msg": "ok", "name": ..., "sexy": ...}`` taken from the first
        matching row, or ``{"msg": "error", "name": None, "sexy": None}``
        when no row matches.

    NOTE(review): the submitted password is compared with the stored column
    as-is, which suggests passwords are not hashed -- confirm.
    """
    logging.info("login")

    user_matches = CustomUser.user == request.user
    password_matches = CustomUser.psd == request.psd
    account = CustomUser.select().where(user_matches, password_matches).first()

    if account is None:
        return {"msg": "error", "name": None, "sexy": None}
    return {"msg": "ok", "name": account.name, "sexy": account.sexy}
@app.post("/api/saveUser")
async def saveUser(request: SaveUser):
    """Insert a ``UserInfo`` row unless the customer already has 100 of them.

    Returns:
        A dict with a message and ``state`` -1 when the customer already owns
        100 or more rows (nothing is inserted), otherwise a success message
        with ``state`` 200 after the insert has executed.
    """
    max_saved_users = 100
    owned_by_customer = UserInfo.select().where(
        UserInfo.customer == request.customer
    )
    if owned_by_customer.count() >= max_saved_users:
        return {"msg": "超过可以保存的用户上限,请联系管理员", "state": -1}

    new_row = request.to_db_data()
    UserInfo.insert(new_row).execute()
    return {"msg": "保存用户信息成功", "state": 200}
def __build_user_object(dt: UserInfo):
    """Convert a ``UserInfo`` row into the dict shape sent to clients.

    Most keys carry the attribute of the same name unchanged. Two are
    renamed: ``isMan`` is ``bool(dt.man)`` and ``joinTime`` is
    ``dt.join_time``. Key order matches the order listed below.
    """
    leading_fields = ("id", "name", "beizhu")
    trailing_fields = (
        "leibie",
        "year", "month", "day", "hour", "minute",
        "sheng", "shi", "qu",
        "niangan", "nianzhi",
        "yuegan", "yuezhi",
        "rigan", "rizhi",
        "shigan", "shizhi",
        "customer",
    )

    record = {field: getattr(dt, field) for field in leading_fields}
    record["isMan"] = bool(dt.man)
    for field in trailing_fields:
        record[field] = getattr(dt, field)
    record["joinTime"] = dt.join_time
    return record
def __do_query_user(customer: str, filter: str):
    """List the enabled ``UserInfo`` rows of *customer* as client dicts.

    Args:
        customer: Value matched against ``UserInfo.customer``.
        filter: When ``None`` every enabled row is returned; otherwise only
            rows whose ``name`` contains this substring. The test runs in
            Python on the fetched rows, not in the query.

    Returns:
        A list of dicts produced by ``__build_user_object``; empty when no
        row qualifies.
    """
    enabled_rows = UserInfo.select().where(
        UserInfo.customer == customer, UserInfo.enabled == 1
    )
    return [
        __build_user_object(row)
        for row in enabled_rows
        if filter is None or filter in row.name
    ]
@app.post("/api/queryUser")
async def queryUser(request: QueryUser):
    """Return the customer's enabled users, optionally filtered by name.

    Passes ``request.customer`` and ``request.filter`` to
    ``__do_query_user`` and JSON-encodes the resulting list.
    """
    return jsonable_encoder(__do_query_user(request.customer, request.filter))
@app.post("/api/deleteUser")
async def deleteUser(request: DeleteUser):
    """Soft-delete one of the customer's users and return the remaining ones.

    Sets ``enabled`` to 0 on the ``UserInfo`` row whose id is ``request.id``
    and whose customer is ``request.customer``, then returns that customer's
    still-enabled users, unfiltered.

    The update is scoped to the requesting customer as well as the id, so a
    request cannot disable a row that belongs to a different customer; such a
    request changes nothing and still returns the caller's own list.
    """
    owned_target = (UserInfo.id == request.id,
                    UserInfo.customer == request.customer)
    UserInfo.update({"enabled": 0}).where(*owned_target).execute()
    return __do_query_user(request.customer, None)
@app.post("/api/test")
async def test(request: Request):
    """Diagnostic endpoint reporting the caller's origin and database state.

    The JSON body has ``message`` ("Hello World" followed by the request's
    ``origin`` header, or "unknown" when the header is absent) and ``db``
    (the results of ``database.is_closed()`` and
    ``database.is_connection_usable()``). The same origin value is sent back
    in the ``Access-Control-Allow-Origin`` response header. Before
    responding, every ``CustomUser`` name is printed; any exception raised
    by that query is logged at INFO level and otherwise ignored.
    """
    origin = request.headers.get('origin', "unknown")

    db_state = (
        f"is_closed: {database.is_closed()}"
        f" is_usable:{database.is_connection_usable()}"
    )
    payload = {"message": "Hello World" + origin, "db": db_state}

    try:
        for account in CustomUser.select():
            print(account.name)
    except Exception as e:
        logging.info(e)

    return JSONResponse(
        content=payload,
        headers={'Access-Control-Allow-Origin': origin},
    )
@app.get("/api/test2")
@peewee_db_close
async def test2():
    """Return the ``a`` value of every ``ZaneTest`` row where it is set.

    Returns:
        ``{"message": "Hello World", "users": [...]}`` with one entry per
        selected row.
    """
    # `!= None` is deliberate: the comparison is made on the model field to
    # build the query filter (presumably an ORM operator overload), so it
    # must not be rewritten as `is not None`.
    rows_with_a = ZaneTest.select().where(ZaneTest.a != None)  # noqa: E711
    return {"message": "Hello World", "users": [row.a for row in rows_with_a]}
|