init
This commit is contained in:
@@ -0,0 +1,379 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
|
||||
from app.core.schemas import DiagnosisDraft, DiagnosisQuestion, DiagnosisQuestionAnswer, DiagnosisResponse
|
||||
from app.core.settings import DEEPSEEK_API_KEY, DEEPSEEK_BASE_URL, DEEPSEEK_MODEL, DEEPSEEK_TIMEOUT
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
SESSION_TTL = 3600 # 1 hour
|
||||
|
||||
CLASSIFY_SYSTEM_PROMPT = """你是宿舍维修故障诊断助手。根据学生故障描述,输出 JSON 对象。
|
||||
|
||||
分类规则:
|
||||
- 电路照明: 灯、电、跳闸、插座、焦味、漏电、火花
|
||||
- 给排水: 漏水、水龙头、下水、厕所、马桶、管道、洗手池
|
||||
- 门窗锁具: 门、锁、窗、把手、柜门
|
||||
- 空调设备: 空调、制冷、制热
|
||||
- 网络设备: 网络、wifi、网速、断网、网口
|
||||
- 家具设施: 床、桌椅、衣柜、抽屉、护栏
|
||||
- 其他: 不属于以上类别
|
||||
|
||||
紧急度规则(urgency):
|
||||
- 紧急: 火花、漏电、冒烟、大面积漏水、安全威胁
|
||||
- 高: 影响基本生活(照明/用水/用电/门禁)
|
||||
- 中: 普通故障影响使用
|
||||
- 低: 无立即影响
|
||||
|
||||
safety_risk: 火花、漏电、冒烟、焦味 设为 true
|
||||
suggested_worker: 电工(电路照明)、水暖维修(给排水)、门窗维修(门窗锁具)、空调维修(空调设备)、
|
||||
网络维护(网络设备)、综合维修(其他)
|
||||
questions: 2-3个有针对性、非模板化的追问,每个含 id(q1/q2/q3) 和 prompt 字段
|
||||
|
||||
输出格式:
|
||||
{"category":"...","urgency":"...","safety_risk":false,"suggested_worker":"...","notes":["..."],"questions":[{"id":"q1","prompt":"..."}]}"""
|
||||
|
||||
SUMMARY_SYSTEM_PROMPT = """根据学生原始描述和补充回答,生成维修工单摘要。
|
||||
格式:"[category]问题,建议[suggested_worker]处理。[现象];[位置];[影响范围]"
|
||||
不超过200字,使用中文标点。"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiagnosisSession:
|
||||
session_id: str
|
||||
category: str
|
||||
urgency: Literal["低", "中", "高", "紧急"]
|
||||
suggested_worker: str
|
||||
safety_risk: bool
|
||||
initial_message: str
|
||||
suggested_categories: list[str]
|
||||
questions: list[DiagnosisQuestion]
|
||||
notes: list[str] = field(default_factory=list)
|
||||
created_at: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
class AIDiagnosisProvider:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
base_url: str = DEEPSEEK_BASE_URL,
|
||||
api_key: str = DEEPSEEK_API_KEY,
|
||||
model: str = DEEPSEEK_MODEL,
|
||||
timeout: float = DEEPSEEK_TIMEOUT,
|
||||
) -> None:
|
||||
self._api_key = api_key
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=base_url.rstrip("/"),
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
timeout=httpx.Timeout(timeout),
|
||||
)
|
||||
self._model = model
|
||||
self._sessions: dict[str, DiagnosisSession] = {}
|
||||
self._last_sweep = time.time()
|
||||
|
||||
def _sweep_expired_sessions(self) -> None:
|
||||
now = time.time()
|
||||
if now - self._last_sweep < 300: # sweep every 5 minutes
|
||||
return
|
||||
self._last_sweep = now
|
||||
expired = [sid for sid, s in self._sessions.items() if now - s.created_at > SESSION_TTL]
|
||||
for sid in expired:
|
||||
del self._sessions[sid]
|
||||
if expired:
|
||||
logger.info("Swept %d expired diagnosis sessions", len(expired))
|
||||
|
||||
async def start(self, message: str) -> DiagnosisResponse:
|
||||
self._sweep_expired_sessions()
|
||||
if not self._api_key:
|
||||
return self._start_local(message)
|
||||
try:
|
||||
result = await self._call_classify(message)
|
||||
category = result["category"]
|
||||
urgency = result["urgency"]
|
||||
worker = result["suggested_worker"]
|
||||
safety_risk = result["safety_risk"]
|
||||
ai_notes = result.get("notes", [])
|
||||
questions = [DiagnosisQuestion(id=q["id"], prompt=q["prompt"]) for q in result.get("questions", [])]
|
||||
if not questions:
|
||||
questions = [
|
||||
DiagnosisQuestion(id="location", prompt="问题具体出现在宿舍的什么位置或设备上?"),
|
||||
DiagnosisQuestion(id="symptom", prompt="请补充最明显的故障现象,方便维修人员带对工具。"),
|
||||
]
|
||||
except (httpx.HTTPError, json.JSONDecodeError, KeyError):
|
||||
return self._start_local(message)
|
||||
|
||||
session_id = uuid4().hex
|
||||
suggested_categories = [category]
|
||||
if category != "其他":
|
||||
suggested_categories.append("其他")
|
||||
session = DiagnosisSession(
|
||||
session_id=session_id,
|
||||
category=category,
|
||||
urgency=urgency,
|
||||
suggested_worker=worker,
|
||||
safety_risk=safety_risk,
|
||||
initial_message=message,
|
||||
suggested_categories=suggested_categories,
|
||||
questions=questions,
|
||||
notes=ai_notes,
|
||||
)
|
||||
self._sessions[session_id] = session
|
||||
return DiagnosisResponse(
|
||||
session_id=session_id,
|
||||
stage="questions",
|
||||
initial_message=message,
|
||||
suggested_categories=suggested_categories,
|
||||
questions=questions,
|
||||
draft=None,
|
||||
)
|
||||
|
||||
async def answer(self, session_id: str, answers: list[DiagnosisQuestionAnswer]) -> DiagnosisResponse:
|
||||
session = self._sessions.get(session_id)
|
||||
if session is None:
|
||||
raise KeyError("session_not_found")
|
||||
|
||||
if time.time() - session.created_at > SESSION_TTL:
|
||||
del self._sessions[session_id]
|
||||
raise KeyError("session_not_found")
|
||||
|
||||
session.created_at = time.time()
|
||||
|
||||
if not self._api_key:
|
||||
return self._answer_local(session, answers)
|
||||
try:
|
||||
summary = await self._call_summary(session, answers)
|
||||
except (httpx.HTTPError, json.JSONDecodeError):
|
||||
return self._answer_local(session, answers)
|
||||
|
||||
notes = list(session.notes)
|
||||
if session.safety_risk:
|
||||
notes.append("建议尽快断电或停止继续使用相关设备。")
|
||||
answer_map = {item.question_id: item.answer.strip() for item in answers}
|
||||
for question in session.questions:
|
||||
value = answer_map.get(question.id)
|
||||
if value:
|
||||
notes.append(f"{question.prompt}:{value}")
|
||||
|
||||
draft = DiagnosisDraft(
|
||||
category=session.category,
|
||||
urgency=session.urgency,
|
||||
summary=summary,
|
||||
safety_risk=session.safety_risk,
|
||||
suggested_worker=session.suggested_worker,
|
||||
notes=notes,
|
||||
)
|
||||
return DiagnosisResponse(
|
||||
session_id=session.session_id,
|
||||
stage="draft",
|
||||
initial_message=session.initial_message,
|
||||
suggested_categories=session.suggested_categories,
|
||||
questions=[],
|
||||
draft=draft,
|
||||
)
|
||||
|
||||
def _start_local(self, message: str) -> DiagnosisResponse:
|
||||
category, urgency, worker, safety_risk, questions = self._local_classify(message)
|
||||
session_id = uuid4().hex
|
||||
suggested_categories = [category]
|
||||
if category != "其他":
|
||||
suggested_categories.append("其他")
|
||||
session = DiagnosisSession(
|
||||
session_id=session_id,
|
||||
category=category,
|
||||
urgency=urgency,
|
||||
suggested_worker=worker,
|
||||
safety_risk=safety_risk,
|
||||
initial_message=message,
|
||||
suggested_categories=suggested_categories,
|
||||
questions=questions,
|
||||
)
|
||||
self._sessions[session_id] = session
|
||||
return DiagnosisResponse(
|
||||
session_id=session_id,
|
||||
stage="questions",
|
||||
initial_message=message,
|
||||
suggested_categories=suggested_categories,
|
||||
questions=questions,
|
||||
draft=None,
|
||||
)
|
||||
|
||||
def _answer_local(self, session: DiagnosisSession, answers: list[DiagnosisQuestionAnswer]) -> DiagnosisResponse:
|
||||
answer_map = {item.question_id: item.answer.strip() for item in answers}
|
||||
notes = []
|
||||
if session.safety_risk:
|
||||
notes.append("建议尽快断电或停止继续使用相关设备。")
|
||||
if "location" in answer_map:
|
||||
notes.append(f"具体位置:{answer_map['location']}")
|
||||
if "impact" in answer_map:
|
||||
notes.append(f"影响范围:{answer_map['impact']}")
|
||||
if "symptom" in answer_map:
|
||||
notes.append(f"补充现象:{answer_map['symptom']}")
|
||||
|
||||
summary = self._local_build_summary(session, answer_map)
|
||||
draft = DiagnosisDraft(
|
||||
category=session.category,
|
||||
urgency=session.urgency,
|
||||
summary=summary,
|
||||
safety_risk=session.safety_risk,
|
||||
suggested_worker=session.suggested_worker,
|
||||
notes=notes,
|
||||
)
|
||||
return DiagnosisResponse(
|
||||
session_id=session.session_id,
|
||||
stage="draft",
|
||||
initial_message=session.initial_message,
|
||||
suggested_categories=session.suggested_categories,
|
||||
questions=[],
|
||||
draft=draft,
|
||||
)
|
||||
|
||||
def _local_classify(
|
||||
self, message: str
|
||||
) -> tuple[str, Literal["低", "中", "高", "紧急"], str, bool, list[DiagnosisQuestion]]:
|
||||
text = message.lower()
|
||||
category = "其他"
|
||||
urgency: Literal["低", "中", "高", "紧急"] = "中"
|
||||
worker = "综合维修"
|
||||
safety_risk = False
|
||||
|
||||
if any(keyword in text for keyword in ["灯", "电", "跳闸", "插座", "焦味"]):
|
||||
category = "电路照明"
|
||||
worker = "电工"
|
||||
urgency = "高"
|
||||
if any(keyword in text for keyword in ["漏水", "水龙头", "下水", "厕所", "马桶"]):
|
||||
category = "给排水"
|
||||
worker = "水暖维修"
|
||||
urgency = "高"
|
||||
if any(keyword in text for keyword in ["门", "锁", "窗", "把手"]):
|
||||
category = "门窗锁具"
|
||||
worker = "门窗维修"
|
||||
urgency = "中"
|
||||
if any(keyword in text for keyword in ["空调", "制冷", "制热"]):
|
||||
category = "空调设备"
|
||||
worker = "空调维修"
|
||||
urgency = "中"
|
||||
if any(keyword in text for keyword in ["网络", "wifi", "网速", "断网"]):
|
||||
category = "网络设备"
|
||||
worker = "网络维护"
|
||||
urgency = "中"
|
||||
if any(keyword in text for keyword in ["烟", "火花", "焦味", "漏电"]):
|
||||
urgency = "紧急"
|
||||
safety_risk = True
|
||||
|
||||
questions = [
|
||||
DiagnosisQuestion(id="location", prompt="问题具体出现在宿舍的什么位置或设备上?"),
|
||||
DiagnosisQuestion(id="impact", prompt="这个问题目前影响范围有多大?比如是否影响整间宿舍使用。"),
|
||||
DiagnosisQuestion(id="symptom", prompt="请补充最明显的故障现象,方便维修人员带对工具。"),
|
||||
]
|
||||
return category, urgency, worker, safety_risk, questions
|
||||
|
||||
def _local_build_summary(self, session: DiagnosisSession, answers: dict[str, str]) -> str:
|
||||
parts = [session.initial_message.strip()]
|
||||
for key in ("location", "impact", "symptom"):
|
||||
value = answers.get(key)
|
||||
if value:
|
||||
parts.append(value)
|
||||
joined = ";".join(parts)
|
||||
return f"{session.category}问题,建议{session.suggested_worker}处理。{joined}"
|
||||
|
||||
async def _call_classify(self, message: str) -> dict[str, Any]:
|
||||
response = await self._client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": self._model,
|
||||
"messages": [
|
||||
{"role": "system", "content": CLASSIFY_SYSTEM_PROMPT},
|
||||
{"role": "user", "content": message},
|
||||
],
|
||||
"response_format": {"type": "json_object"},
|
||||
"temperature": 0.1,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
body = response.json()
|
||||
content = body["choices"][0]["message"]["content"]
|
||||
result = json.loads(content)
|
||||
_validate_classify_result(result)
|
||||
return result
|
||||
|
||||
async def _call_summary(self, session: DiagnosisSession, answers: list[DiagnosisQuestionAnswer]) -> str:
|
||||
answer_texts = "\n".join(f"- {item.question_id}: {item.answer.strip()}" for item in answers)
|
||||
response = await self._client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": self._model,
|
||||
"messages": [
|
||||
{"role": "system", "content": SUMMARY_SYSTEM_PROMPT},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"原始描述:{session.initial_message}\n"
|
||||
f"故障类别:{session.category}\n"
|
||||
f"负责工种:{session.suggested_worker}\n"
|
||||
f"补充回答:\n{answer_texts}"
|
||||
),
|
||||
},
|
||||
],
|
||||
"temperature": 0.1,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
body = response.json()
|
||||
return body["choices"][0]["message"]["content"].strip()
|
||||
|
||||
|
||||
def _validate_classify_result(result: dict[str, Any]) -> None:
|
||||
valid_categories = {"电路照明", "给排水", "门窗锁具", "空调设备", "网络设备", "家具设施", "其他"}
|
||||
valid_urgencies = {"低", "中", "高", "紧急"}
|
||||
|
||||
category = result.get("category", "")
|
||||
if category not in valid_categories:
|
||||
result["category"] = "其他"
|
||||
|
||||
urgency = result.get("urgency", "")
|
||||
if urgency not in valid_urgencies:
|
||||
result["urgency"] = "中"
|
||||
|
||||
if not isinstance(result.get("safety_risk"), bool):
|
||||
result["safety_risk"] = False
|
||||
|
||||
worker = result.get("suggested_worker", "")
|
||||
if not worker:
|
||||
result["suggested_worker"] = "综合维修"
|
||||
|
||||
questions = result.get("questions")
|
||||
if not isinstance(questions, list):
|
||||
result["questions"] = []
|
||||
else:
|
||||
result["questions"] = [q for q in questions if isinstance(q, dict) and "id" in q and "prompt" in q]
|
||||
|
||||
notes = result.get("notes")
|
||||
if not isinstance(notes, list):
|
||||
result["notes"] = []
|
||||
|
||||
|
||||
_dp: AIDiagnosisProvider | None = None
|
||||
|
||||
|
||||
def get_diagnosis_provider() -> AIDiagnosisProvider:
|
||||
global _dp
|
||||
if _dp is None:
|
||||
_dp = AIDiagnosisProvider()
|
||||
return _dp
|
||||
|
||||
|
||||
async def close_diagnosis_provider() -> None:
|
||||
global _dp
|
||||
if _dp is not None:
|
||||
await _dp._client.aclose()
|
||||
_dp = None
|
||||
Reference in New Issue
Block a user