main.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
import asyncio
import logging
import os

import openai
from fastapi import FastAPI, Request
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from openai import OpenAI
from starlette.middleware.cors import CORSMiddleware

from DBTools import DBUserInfo, DBCustomUser
from LocalModel import CustomLogin, SaveUser, QueryUser, DeleteUser
from logic import *
  12. API_KEY = "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"
  13. app = FastAPI()
  14. app.add_middleware(
  15. CORSMiddleware,
  16. allow_origins=["*"],
  17. allow_credentials=True,
  18. allow_methods=["*"],
  19. allow_headers=["*"],
  20. )
  21. logging.basicConfig(
  22. level=logging.INFO, # 设置日志级别
  23. format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', # 日志格式
  24. datefmt='%Y-%m-%d %H:%M:%S', # 时间格式
  25. filename='app.log', # 日志文件存储位置
  26. filemode='a' # 文件模式,'a'为追加模式,默认为'a',还可以选择'w'覆写模式
  27. )
  28. async def ai_stream(content: str):
  29. client = openai.OpenAI(api_key=API_KEY)
  30. completion = client.chat.completions.create(
  31. model="gpt-3.5-turbo",
  32. stream=True,
  33. messages=[
  34. {"role": "system", "content": content}
  35. ]
  36. )
  37. try:
  38. for chunk in completion:
  39. if chunk.choices[0].delta.content:
  40. yield chunk.choices[0].delta.content + "\n"
  41. await asyncio.sleep(0.01) # 稍微暂停以允许其他任务执行
  42. except Exception as e:
  43. yield f"Error: {e}\n"
  44. def ai_normal(content: str):
  45. client = OpenAI(api_key=API_KEY)
  46. completion = client.chat.completions.create(
  47. model="gpt-3.5-turbo",
  48. messages=[
  49. {"role": "system",
  50. "content": content},
  51. ]
  52. )
  53. return {"msg": completion.choices[0].message.content}
  54. # @app.post("/ai/")
  55. # async def do_ai(question: Question):
  56. # if question.stream:
  57. # return StreamingResponse(ai_stream(question.content), media_type="text/event-stream")
  58. # else:
  59. # return ai_normal(question.content)
  60. # def get_value(s: str):
  61. # return s
  62. # class MyRequest(BaseModel):
  63. # content: str
  64. # def test_func(arg1: str):
  65. # print(arg1)
  66. # return "nice"
  67. # @app.post("/func/")
  68. # async def call_func(mq: MyRequest):
  69. # client = OpenAI(api_key=API_KEY)
  70. # messages = []
  71. # messages.append({"role": "system",
  72. # "content": "You are a helpful assistant"})
  73. # messages.append({"role": "system",
  74. # "content": "If you need to call a function but you do not have enough parameters, ask the user to provide you with the missing parameters."})
  75. # messages.append({"role": "system",
  76. # "content": "You must answer in Chinese"})
  77. # messages.append({"role": "user", "content": mq.content})
  78. #
  79. # tools = [{
  80. # "type": "function",
  81. # "function": {
  82. # "name": "get_user_birthday",
  83. # "description": "get user birthday",
  84. # "parameters": {
  85. # "type": "object",
  86. # "properties": {
  87. # "birthday": {
  88. # "type": "string",
  89. # "description": "user birthday"
  90. # },
  91. # "city": {
  92. # "type": "string",
  93. # "description": "city of user born"
  94. # }
  95. # },
  96. # "required": ["birthday", "city"],
  97. # }
  98. # }
  99. # }]
  100. #
  101. # completion1 = client.chat.completions.create(
  102. # model="gpt-4",
  103. # messages=messages,
  104. # tools=tools
  105. # )
  106. # ast1 = completion1.choices[0].message
  107. # return {"msg": ast1}
  108. #
  109. # class YearInfo(BaseModel):
  110. # year: int
  111. # month: int
  112. # day: int
  113. # hour: int
  114. # minute: int
  115. #
  116. # @app.post("/wnl/add/")
  117. # async def add_wnl(info: YearInfo):
  118. # # result = []
  119. # # for year in range(info.from_year, info.to_year+1):
  120. # # for month in range(1, 13):
  121. # # max_day = 30
  122. # # if month == 2:
  123. # # if year % 4 == 0:
  124. # # max_day = 29
  125. # # else:
  126. # # max_day = 28
  127. # # elif month in (1, 3, 5, 7, 8, 10, 12):
  128. # # max_day = 31
  129. # # for day in range(1, max_day + 1):
  130. # # result.append({
  131. # # "nian": year,
  132. # # "yue": month,
  133. # # "ri": day
  134. # # })
  135. # # Wannianli.insert_many(result).execute()
  136. # wnl = Wannianli.select()
  137. # ct = len(wnl)
  138. # return {"data": "新增了" + str(ct) + "条数据"}
  139. #
  140. #
  141. # @app.post("/wnl/update/")
  142. # async def update_wnl(info: YearInfo):
  143. # data = get_wannianli_data(info.year, info.month, info.day)
  144. # msg = []
  145. # if data is not None:
  146. # msg = [data.nian_gan, data.nian_zhi,
  147. # data.yue_gan, data.yue_zhi,
  148. # data.ri_gan, data.ri_zhi]
  149. # hour_data = get_hour_of_day(data.ri_gan, info.hour)
  150. # msg.append(hour_data[0])
  151. # msg.append(hour_data[1])
  152. # return {"date": str(info.year) + "-" + str(info.month) + "-" + str(info.day) + " " + str(info.hour) + ":" + str(
  153. # info.minute),
  154. # "msg": msg}
  155. @app.post("/api/getSiZhuInfo")
  156. async def getSiZhuInfo(request: SiZhuInfoRequest):
  157. startDtm = None
  158. if request.mode == 2:
  159. startDtm = calc_date_of_sizhu(request)
  160. bazi = BaZi(request)
  161. dc = DataCenter(bazi)
  162. if startDtm is not None:
  163. bazi.taiyangshi = startDtm.__str__()
  164. fill_sizhu_in_bazi(bazi, dc)
  165. # logging.info("this is a info")
  166. # logging.info(jsonable_encoder(bazi))
  167. # print(jsonable_encoder(bazi))
  168. return jsonable_encoder(bazi)
  169. @app.post("/api/customLogin")
  170. async def customLogin(request: CustomLogin):
  171. logging.info("login")
  172. dt = DBCustomUser.query_first_by(user=request.user, psd=request.psd)
  173. if dt is not None:
  174. return {"msg": "ok", "name": dt.name, "sexy": dt.sexy}
  175. else:
  176. return {"msg": "error", "name": None, "sexy": None}
  177. @app.post("/api/saveUser")
  178. async def saveUser(request: SaveUser):
  179. dts = DBUserInfo.query_by(customer=request.customer)
  180. if len(dts) >= 100:
  181. return {"msg": "超过可以保存的用户上限,请联系管理员", "state": -1}
  182. from datetime import datetime
  183. tm = datetime.now()
  184. DBUserInfo.insert(name=request.name,
  185. beizhu=request.beizhu,
  186. man=request.isMan,
  187. leibie=request.leibie,
  188. year=request.year,
  189. month=request.month,
  190. day=request.day,
  191. hour=request.hour,
  192. minute=request.minute,
  193. sheng=request.sheng,
  194. shi=request.shi,
  195. qu=request.qu,
  196. niangan=request.niangan,
  197. nianzhi=request.nianzhi,
  198. yuegan=request.yuegan,
  199. yuezhi=request.yuezhi,
  200. rigan=request.rigan,
  201. rizhi=request.rizhi,
  202. shigan=request.shigan,
  203. shizhi=request.shizhi,
  204. customer=request.customer,
  205. joinTime=tm.strftime("%Y-%m-%d %H:%M:%S"),
  206. enabled=1,
  207. )
  208. return {"msg": "保存用户信息成功", "state": 200}
  209. def __build_user_object1(dt: DBUserInfo):
  210. return {
  211. "id": dt.id,
  212. "name": dt.name,
  213. "beizhu": dt.beizhu,
  214. "isMan": bool(dt.man),
  215. "leibie": dt.leibie,
  216. "year": dt.year,
  217. "month": dt.month,
  218. "day": dt.day,
  219. "hour": dt.hour,
  220. "minute": dt.minute,
  221. "sheng": dt.sheng,
  222. "shi": dt.shi,
  223. "qu": dt.qu,
  224. "niangan": dt.niangan,
  225. "nianzhi": dt.nianzhi,
  226. "yuegan": dt.yuegan,
  227. "yuezhi": dt.yuezhi,
  228. "rigan": dt.rigan,
  229. "rizhi": dt.rizhi,
  230. "shigan": dt.shigan,
  231. "shizhi": dt.shizhi,
  232. "customer": dt.customer,
  233. "joinTime": dt.joinTime
  234. }
  235. def __do_query_user(customer: str, filter: str):
  236. dts = DBUserInfo.query_by(customer=customer, enabled=1)
  237. data = []
  238. if len(dts) > 0:
  239. for dt in dts:
  240. if filter is None:
  241. data.append(__build_user_object1(dt))
  242. else:
  243. if filter in dt.name:
  244. data.append(__build_user_object1(dt))
  245. return data
  246. @app.post("/api/queryUser")
  247. async def queryUser(request: QueryUser):
  248. data = __do_query_user(request.customer, request.filter)
  249. return jsonable_encoder(data)
  250. @app.post("/api/deleteUser")
  251. async def deleteUser(request: DeleteUser):
  252. ss = DBUserInfo.new_session()
  253. data = ss.query(DBUserInfo).filter_by(id=request.id).first()
  254. if data is not None:
  255. data.enabled = 0
  256. ss.commit()
  257. ss.close()
  258. return __do_query_user(request.customer, None)
  259. @app.post("/api/test")
  260. async def test(request: Request):
  261. # ContentLogic.createContent()
  262. return JSONResponse(content="ok")
  263. @app.get("/api/test2")
  264. async def test2():
  265. return {"message": "Hello World"}