main.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. import asyncio
  2. import logging
  3. import openai
  4. from fastapi import FastAPI, Request
  5. from fastapi.encoders import jsonable_encoder
  6. from fastapi.responses import JSONResponse
  7. from openai import OpenAI
  8. from pydantic import BaseModel
  9. from pymysql import OperationalError
  10. from starlette.middleware.cors import CORSMiddleware
  11. from LocalModel import CustomLogin, SaveUser, QueryUser, DeleteUser
  12. from db_decorator import peewee_db_close
  13. from logic import *
  14. from model import CustomUser, UserInfo, ZaneTest, database
  15. API_KEY = "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"
  16. app = FastAPI()
  17. app.add_middleware(
  18. CORSMiddleware,
  19. allow_origins=["*"],
  20. allow_credentials=True,
  21. allow_methods=["*"],
  22. allow_headers=["*"],
  23. )
  24. logging.basicConfig(
  25. level=logging.INFO, # 设置日志级别
  26. format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', # 日志格式
  27. datefmt='%Y-%m-%d %H:%M:%S', # 时间格式
  28. filename='app.log', # 日志文件存储位置
  29. filemode='a' # 文件模式,'a'为追加模式,默认为'a',还可以选择'w'覆写模式
  30. )
  31. def check_db_connect():
  32. try:
  33. database.connect(reuse_if_open=True)
  34. dt = CustomUser.select()
  35. for item in dt:
  36. print(item.id)
  37. except OperationalError as e:
  38. if 'MySQL server has gone away' in str(e):
  39. database.close()
  40. database.connect()
  41. logging.info("reconnect database")
  42. # threading.Timer(60 * 60, check_db_connect).start()
  43. # check_db_connect()
  44. class Question(BaseModel):
  45. user: str
  46. content: str
  47. stream: bool = True
  48. async def ai_stream(content: str):
  49. client = openai.OpenAI(api_key=API_KEY)
  50. completion = client.chat.completions.create(
  51. model="gpt-3.5-turbo",
  52. stream=True,
  53. messages=[
  54. {"role": "system", "content": content}
  55. ]
  56. )
  57. try:
  58. for chunk in completion:
  59. if chunk.choices[0].delta.content:
  60. yield chunk.choices[0].delta.content + "\n"
  61. await asyncio.sleep(0.01) # 稍微暂停以允许其他任务执行
  62. except Exception as e:
  63. yield f"Error: {e}\n"
  64. def ai_normal(content: str):
  65. client = OpenAI(api_key=API_KEY)
  66. completion = client.chat.completions.create(
  67. model="gpt-3.5-turbo",
  68. messages=[
  69. {"role": "system",
  70. "content": content},
  71. ]
  72. )
  73. return {"msg": completion.choices[0].message.content}
  74. # @app.post("/ai/")
  75. # async def do_ai(question: Question):
  76. # if question.stream:
  77. # return StreamingResponse(ai_stream(question.content), media_type="text/event-stream")
  78. # else:
  79. # return ai_normal(question.content)
  80. # def get_value(s: str):
  81. # return s
  82. # class MyRequest(BaseModel):
  83. # content: str
  84. # def test_func(arg1: str):
  85. # print(arg1)
  86. # return "nice"
  87. # @app.post("/func/")
  88. # async def call_func(mq: MyRequest):
  89. # client = OpenAI(api_key=API_KEY)
  90. # messages = []
  91. # messages.append({"role": "system",
  92. # "content": "You are a helpful assistant"})
  93. # messages.append({"role": "system",
  94. # "content": "If you need to call a function but you do not have enough parameters, ask the user to provide you with the missing parameters."})
  95. # messages.append({"role": "system",
  96. # "content": "You must answer in Chinese"})
  97. # messages.append({"role": "user", "content": mq.content})
  98. #
  99. # tools = [{
  100. # "type": "function",
  101. # "function": {
  102. # "name": "get_user_birthday",
  103. # "description": "get user birthday",
  104. # "parameters": {
  105. # "type": "object",
  106. # "properties": {
  107. # "birthday": {
  108. # "type": "string",
  109. # "description": "user birthday"
  110. # },
  111. # "city": {
  112. # "type": "string",
  113. # "description": "city of user born"
  114. # }
  115. # },
  116. # "required": ["birthday", "city"],
  117. # }
  118. # }
  119. # }]
  120. #
  121. # completion1 = client.chat.completions.create(
  122. # model="gpt-4",
  123. # messages=messages,
  124. # tools=tools
  125. # )
  126. # ast1 = completion1.choices[0].message
  127. # return {"msg": ast1}
  128. #
  129. # class YearInfo(BaseModel):
  130. # year: int
  131. # month: int
  132. # day: int
  133. # hour: int
  134. # minute: int
  135. #
  136. # @app.post("/wnl/add/")
  137. # async def add_wnl(info: YearInfo):
  138. # # result = []
  139. # # for year in range(info.from_year, info.to_year+1):
  140. # # for month in range(1, 13):
  141. # # max_day = 30
  142. # # if month == 2:
  143. # # if year % 4 == 0:
  144. # # max_day = 29
  145. # # else:
  146. # # max_day = 28
  147. # # elif month in (1, 3, 5, 7, 8, 10, 12):
  148. # # max_day = 31
  149. # # for day in range(1, max_day + 1):
  150. # # result.append({
  151. # # "nian": year,
  152. # # "yue": month,
  153. # # "ri": day
  154. # # })
  155. # # Wannianli.insert_many(result).execute()
  156. # wnl = Wannianli.select()
  157. # ct = len(wnl)
  158. # return {"data": "新增了" + str(ct) + "条数据"}
  159. #
  160. #
  161. # @app.post("/wnl/update/")
  162. # async def update_wnl(info: YearInfo):
  163. # data = get_wannianli_data(info.year, info.month, info.day)
  164. # msg = []
  165. # if data is not None:
  166. # msg = [data.nian_gan, data.nian_zhi,
  167. # data.yue_gan, data.yue_zhi,
  168. # data.ri_gan, data.ri_zhi]
  169. # hour_data = get_hour_of_day(data.ri_gan, info.hour)
  170. # msg.append(hour_data[0])
  171. # msg.append(hour_data[1])
  172. # return {"date": str(info.year) + "-" + str(info.month) + "-" + str(info.day) + " " + str(info.hour) + ":" + str(
  173. # info.minute),
  174. # "msg": msg}
  175. @app.post("/api/getSiZhuInfo")
  176. async def getSiZhuInfo(request: SiZhuInfoRequest):
  177. startDtm = None
  178. if request.mode == 2:
  179. startDtm = calc_date_of_sizhu(request)
  180. bazi = BaZi(request)
  181. dc = DataCenter(bazi)
  182. if startDtm is not None:
  183. bazi.taiyangshi = startDtm.__str__()
  184. fill_sizhu_in_bazi(bazi, dc)
  185. # logging.info("this is a info")
  186. # logging.info(jsonable_encoder(bazi))
  187. # print(jsonable_encoder(bazi))
  188. return jsonable_encoder(bazi)
  189. @app.post("/api/customLogin")
  190. async def customLogin(request: CustomLogin):
  191. logging.info("login")
  192. dt = CustomUser.select().where(CustomUser.user == request.user,
  193. CustomUser.psd == request.psd).first()
  194. if dt is not None:
  195. return {"msg": "ok", "name": dt.name, "sexy": dt.sexy}
  196. else:
  197. return {"msg": "error", "name": None, "sexy": None}
  198. @app.post("/api/saveUser")
  199. async def saveUser(request: SaveUser):
  200. ct = UserInfo.select().where(UserInfo.customer == request.customer).count()
  201. if ct >= 100:
  202. return {"msg": "超过可以保存的用户上限,请联系管理员", "state": -1}
  203. UserInfo.insert(request.to_db_data()).execute()
  204. return {"msg": "保存用户信息成功", "state": 200}
  205. def __build_user_object(dt: UserInfo):
  206. return {
  207. "id": dt.id,
  208. "name": dt.name,
  209. "beizhu": dt.beizhu,
  210. "isMan": bool(dt.man),
  211. "leibie": dt.leibie,
  212. "year": dt.year,
  213. "month": dt.month,
  214. "day": dt.day,
  215. "hour": dt.hour,
  216. "minute": dt.minute,
  217. "sheng": dt.sheng,
  218. "shi": dt.shi,
  219. "qu": dt.qu,
  220. "niangan": dt.niangan,
  221. "nianzhi": dt.nianzhi,
  222. "yuegan": dt.yuegan,
  223. "yuezhi": dt.yuezhi,
  224. "rigan": dt.rigan,
  225. "rizhi": dt.rizhi,
  226. "shigan": dt.shigan,
  227. "shizhi": dt.shizhi,
  228. "customer": dt.customer,
  229. "joinTime": dt.join_time
  230. }
  231. def __do_query_user(customer: str, filter: str):
  232. dts = UserInfo.select().where(UserInfo.customer == customer, UserInfo.enabled == 1)
  233. data = []
  234. if len(dts) > 0:
  235. for dt in dts:
  236. if filter is None:
  237. data.append(__build_user_object(dt))
  238. else:
  239. if filter in dt.name:
  240. data.append(__build_user_object(dt))
  241. return data
  242. @app.post("/api/queryUser")
  243. async def queryUser(request: QueryUser):
  244. data = __do_query_user(request.customer, request.filter)
  245. return jsonable_encoder(data)
  246. @app.post("/api/deleteUser")
  247. async def deleteUser(request: DeleteUser):
  248. UserInfo.update({"enabled": 0}).where(UserInfo.id == request.id).execute()
  249. return __do_query_user(request.customer, None)
  250. @app.post("/api/test")
  251. async def test(request: Request):
  252. request_origin = request.headers.get('origin')
  253. if request_origin is None:
  254. request_origin = "unknown"
  255. content = {"message": "Hello World" +
  256. request_origin, "db": "disconnect!!!"}
  257. headers = {'Access-Control-Allow-Origin': request_origin}
  258. content["db"] = "is_closed: " + \
  259. str(database.is_closed()) + " is_usable:" + \
  260. str(database.is_connection_usable())
  261. try:
  262. dt = CustomUser.select()
  263. for item in dt:
  264. print(item.name)
  265. except Exception as e:
  266. logging.info(e)
  267. return JSONResponse(content=content, headers=headers)
  268. @app.get("/api/test2")
  269. @peewee_db_close
  270. async def test2():
  271. zane_list = ZaneTest.select().where(ZaneTest.a != None)
  272. zane_list = [dt.a for dt in zane_list]
  273. return {"message": "Hello World", "users": zane_list}