Files
hci_work/backend/app/services/diagnosis.py
T
2026-06-06 23:54:11 +08:00

380 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Literal
from uuid import uuid4
import httpx
from app.core.schemas import DiagnosisDraft, DiagnosisQuestion, DiagnosisQuestionAnswer, DiagnosisResponse
from app.core.settings import DEEPSEEK_API_KEY, DEEPSEEK_BASE_URL, DEEPSEEK_MODEL, DEEPSEEK_TIMEOUT
# Module-level logger for the diagnosis service.
logger = logging.getLogger(__name__)
# Seconds a diagnosis session may stay inactive before it is discarded.
# The provider refreshes a session's timestamp each time it is answered.
SESSION_TTL = 3600 # 1 hour
# System prompt for the classification call. It tells the model the allowed
# categories, the urgency levels, the words that set safety_risk, the
# category -> worker mapping, and the exact JSON shape to return.
# The reply is parsed with json.loads and sanitized by _validate_classify_result.
# NOTE(review): 家具设施 is a valid category but has no suggested_worker entry
# below -- confirm which worker it should route to.
CLASSIFY_SYSTEM_PROMPT = """你是宿舍维修故障诊断助手。根据学生故障描述,输出 JSON 对象。
分类规则:
- 电路照明: 灯、电、跳闸、插座、焦味、漏电、火花
- 给排水: 漏水、水龙头、下水、厕所、马桶、管道、洗手池
- 门窗锁具: 门、锁、窗、把手、柜门
- 空调设备: 空调、制冷、制热
- 网络设备: 网络、wifi、网速、断网、网口
- 家具设施: 床、桌椅、衣柜、抽屉、护栏
- 其他: 不属于以上类别
紧急度规则(urgency):
- 紧急: 火花、漏电、冒烟、大面积漏水、安全威胁
- 高: 影响基本生活(照明/用水/用电/门禁)
- 中: 普通故障影响使用
- 低: 无立即影响
safety_risk: 火花、漏电、冒烟、焦味 设为 true
suggested_worker: 电工(电路照明)、水暖维修(给排水)、门窗维修(门窗锁具)、空调维修(空调设备)、
网络维护(网络设备)、综合维修(其他)
questions: 2-3个有针对性、非模板化的追问,每个含 id(q1/q2/q3) 和 prompt 字段
输出格式:
{"category":"...","urgency":"...","safety_risk":false,"suggested_worker":"...","notes":["..."],"questions":[{"id":"q1","prompt":"..."}]}"""
# System prompt for the summary call. The user message built in _call_summary
# carries the original description, the category, the worker and the answers.
SUMMARY_SYSTEM_PROMPT = """根据学生原始描述和补充回答,生成维修工单摘要。
格式:"[category]问题,建议[suggested_worker]处理。[现象][位置][影响范围]"
不超过200字,使用中文标点。"""
@dataclass
class DiagnosisSession:
    """Server-side state of one in-progress diagnosis conversation."""

    session_id: str  # uuid4 hex id handed back to the client
    category: str
    # The four levels defined in CLASSIFY_SYSTEM_PROMPT (low / medium / high / urgent).
    urgency: Literal["低", "中", "高", "紧急"]
    suggested_worker: str
    safety_risk: bool
    initial_message: str  # the student's original fault description
    suggested_categories: list[str]
    questions: list[DiagnosisQuestion]  # follow-up questions shown to the student
    notes: list[str] = field(default_factory=list)  # notes carried into the draft
    # Epoch seconds. Despite the name, the provider refreshes this whenever the
    # session is answered, so it effectively records the last activity for TTL expiry.
    created_at: float = field(default_factory=time.time)
class AIDiagnosisProvider:
    """Two-step dormitory repair diagnosis backed by a DeepSeek chat-completions API.

    ``start`` classifies the student's free-text description and returns
    follow-up questions. ``answer`` turns the answers into a work-order draft.
    If no API key is configured, or a remote call fails or returns malformed
    data, both steps fall back to local keyword rules instead of raising.

    Sessions live in an in-memory dict on this instance. They expire after
    ``SESSION_TTL`` seconds without activity.
    """

    # Minimum number of seconds between two scans of the session dict.
    _SWEEP_INTERVAL = 300

    _SAFETY_NOTE = "建议尽快断电或停止继续使用相关设备。"

    # Note prefixes used for answers to the local question ids.
    _LOCAL_ANSWER_LABELS = {
        "location": "具体位置:",
        "impact": "影响范围:",
        "symptom": "补充现象:",
    }

    # Local keyword rules: (keywords, category, worker, urgency). Rules are
    # checked in order and a later match overrides an earlier one. Keywords and
    # urgencies follow CLASSIFY_SYSTEM_PROMPT: 高 = affects lighting, water,
    # power or door access; 中 = ordinary fault.
    _LOCAL_RULES = (
        (("灯", "电", "跳闸", "插座", "焦味"), "电路照明", "电工", "高"),
        (("漏水", "水龙头", "下水", "厕所", "马桶"), "给排水", "水暖维修", "高"),
        (("门", "锁", "窗", "把手"), "门窗锁具", "门窗维修", "高"),
        (("空调", "制冷", "制热"), "空调设备", "空调维修", "中"),
        (("网络", "wifi", "网速", "断网"), "网络设备", "网络维护", "中"),
    )
    # Any of these marks the message urgent and flags a safety risk (per the prompt).
    _LOCAL_SAFETY_KEYWORDS = ("冒烟", "火花", "焦味", "漏电")

    def __init__(
        self,
        *,
        base_url: str = DEEPSEEK_BASE_URL,
        api_key: str = DEEPSEEK_API_KEY,
        model: str = DEEPSEEK_MODEL,
        timeout: float = DEEPSEEK_TIMEOUT,
    ) -> None:
        self._api_key = api_key
        self._client = httpx.AsyncClient(
            base_url=base_url.rstrip("/"),
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            timeout=httpx.Timeout(timeout),
        )
        self._model = model
        self._sessions: dict[str, DiagnosisSession] = {}
        self._last_sweep = time.time()

    def _sweep_expired_sessions(self) -> None:
        """Drop sessions inactive for longer than SESSION_TTL.

        The scan runs at most once every ``_SWEEP_INTERVAL`` seconds.
        """
        now = time.time()
        if now - self._last_sweep < self._SWEEP_INTERVAL:
            return
        self._last_sweep = now
        expired = [sid for sid, s in self._sessions.items() if now - s.created_at > SESSION_TTL]
        for sid in expired:
            del self._sessions[sid]
        if expired:
            logger.info("Swept %d expired diagnosis sessions", len(expired))

    async def start(self, message: str) -> DiagnosisResponse:
        """Classify *message* and open a session that asks follow-up questions.

        The remote model is used when an API key is configured. Local keyword
        rules are used when there is no key or when the remote call fails.

        Returns:
            A ``stage="questions"`` response carrying the new session id.
        """
        self._sweep_expired_sessions()
        if not self._api_key:
            return self._start_local(message)
        try:
            result = await self._call_classify(message)
            # Every field is read inside the try: a malformed model reply must
            # fall back to the local rules rather than surface as an error.
            category = result["category"]
            urgency = result["urgency"]
            worker = result["suggested_worker"]
            safety_risk = result["safety_risk"]
            notes = [note for note in result.get("notes", []) if isinstance(note, str)]
            questions = [
                DiagnosisQuestion(id=str(q["id"]), prompt=str(q["prompt"]))
                for q in result.get("questions", [])
            ]
        # ValueError covers bad JSON and schema validation errors
        # (assuming the schema models raise ValueError subclasses, as pydantic does).
        except (httpx.HTTPError, ValueError, KeyError, IndexError, TypeError):
            logger.warning("AI classification failed; falling back to local rules", exc_info=True)
            return self._start_local(message)
        if not questions:
            questions = [
                DiagnosisQuestion(id="location", prompt="问题具体出现在宿舍的什么位置或设备上?"),
                DiagnosisQuestion(id="symptom", prompt="请补充最明显的故障现象,方便维修人员带对工具。"),
            ]
        return self._open_session(
            message,
            category=category,
            urgency=urgency,
            worker=worker,
            safety_risk=safety_risk,
            questions=questions,
            notes=notes,
        )

    async def answer(self, session_id: str, answers: list[DiagnosisQuestionAnswer]) -> DiagnosisResponse:
        """Turn the student's answers into a work-order draft.

        Raises:
            KeyError: ``"session_not_found"`` if the session is unknown or expired.
        """
        self._sweep_expired_sessions()
        session = self._sessions.get(session_id)
        now = time.time()
        if session is None or now - session.created_at > SESSION_TTL:
            self._sessions.pop(session_id, None)
            raise KeyError("session_not_found")
        # Sliding expiry: each answer keeps the session alive.
        session.created_at = now
        if not self._api_key:
            return self._answer_local(session, answers)
        try:
            summary = await self._call_summary(session, answers)
        except (httpx.HTTPError, ValueError, KeyError, IndexError, TypeError):
            logger.warning("AI summary failed; falling back to local summary", exc_info=True)
            return self._answer_local(session, answers)
        if not summary:
            return self._answer_local(session, answers)
        answer_map = self._answer_map(answers)
        notes = list(session.notes)
        if session.safety_risk:
            notes.append(self._SAFETY_NOTE)
        for question in session.questions:
            value = answer_map.get(question.id)
            if value:
                notes.append(f"{question.prompt}{value}")
        return self._draft_response(session, summary, notes)

    def _open_session(
        self,
        message: str,
        *,
        category: str,
        urgency: Literal["低", "中", "高", "紧急"],
        worker: str,
        safety_risk: bool,
        questions: list[DiagnosisQuestion],
        notes: list[str] | None = None,
    ) -> DiagnosisResponse:
        """Store a new session and return its ``stage="questions"`` response."""
        session_id = uuid4().hex
        # Offer "其他" as an alternative unless it already is the detected category.
        suggested_categories = [category] if category == "其他" else [category, "其他"]
        self._sessions[session_id] = DiagnosisSession(
            session_id=session_id,
            category=category,
            urgency=urgency,
            suggested_worker=worker,
            safety_risk=safety_risk,
            initial_message=message,
            suggested_categories=suggested_categories,
            questions=questions,
            notes=list(notes) if notes else [],
        )
        return DiagnosisResponse(
            session_id=session_id,
            stage="questions",
            initial_message=message,
            suggested_categories=suggested_categories,
            questions=questions,
            draft=None,
        )

    @staticmethod
    def _answer_map(answers: list[DiagnosisQuestionAnswer]) -> dict[str, str]:
        """Map each question id to its stripped answer text (a later duplicate id wins)."""
        return {item.question_id: item.answer.strip() for item in answers}

    @staticmethod
    def _draft_response(session: DiagnosisSession, summary: str, notes: list[str]) -> DiagnosisResponse:
        """Wrap a finished summary and notes into the ``stage="draft"`` response."""
        draft = DiagnosisDraft(
            category=session.category,
            urgency=session.urgency,
            summary=summary,
            safety_risk=session.safety_risk,
            suggested_worker=session.suggested_worker,
            notes=notes,
        )
        return DiagnosisResponse(
            session_id=session.session_id,
            stage="draft",
            initial_message=session.initial_message,
            suggested_categories=session.suggested_categories,
            questions=[],
            draft=draft,
        )

    def _start_local(self, message: str) -> DiagnosisResponse:
        """Open a session classified by the local keyword rules."""
        category, urgency, worker, safety_risk, questions = self._local_classify(message)
        return self._open_session(
            message,
            category=category,
            urgency=urgency,
            worker=worker,
            safety_risk=safety_risk,
            questions=questions,
        )

    def _answer_local(self, session: DiagnosisSession, answers: list[DiagnosisQuestionAnswer]) -> DiagnosisResponse:
        """Build the draft without the remote model.

        This is also the fallback for AI-created sessions. Answers are therefore
        matched against the session's own question ids (e.g. q1/q2) rather than
        only the local ids. Known local ids get a fixed label; any other id is
        prefixed with its question text. Empty answers are skipped.
        """
        answer_map = self._answer_map(answers)
        notes = list(session.notes)
        if session.safety_risk:
            notes.append(self._SAFETY_NOTE)
        for question in session.questions:
            value = answer_map.get(question.id)
            if not value:
                continue
            label = self._LOCAL_ANSWER_LABELS.get(question.id)
            notes.append(f"{label}{value}" if label else f"{question.prompt}{value}")
        summary = self._local_build_summary(session, answer_map)
        return self._draft_response(session, summary, notes)

    def _local_classify(
        self, message: str
    ) -> tuple[str, Literal["低", "中", "高", "紧急"], str, bool, list[DiagnosisQuestion]]:
        """Classify *message* with keyword rules that mirror CLASSIFY_SYSTEM_PROMPT.

        Returns:
            ``(category, urgency, suggested_worker, safety_risk, questions)``.
            A message that matches no rule is 其他 / 综合维修 with urgency 中.
        """
        text = message.lower()
        category, worker = "其他", "综合维修"
        urgency: Literal["低", "中", "高", "紧急"] = "中"
        for keywords, rule_category, rule_worker, rule_urgency in self._LOCAL_RULES:
            if any(keyword in text for keyword in keywords):
                category, worker, urgency = rule_category, rule_worker, rule_urgency
        safety_risk = any(keyword in text for keyword in self._LOCAL_SAFETY_KEYWORDS)
        if safety_risk:
            urgency = "紧急"
        questions = [
            DiagnosisQuestion(id="location", prompt="问题具体出现在宿舍的什么位置或设备上?"),
            DiagnosisQuestion(id="impact", prompt="这个问题目前影响范围有多大?比如是否影响整间宿舍使用。"),
            DiagnosisQuestion(id="symptom", prompt="请补充最明显的故障现象,方便维修人员带对工具。"),
        ]
        return category, urgency, worker, safety_risk, questions

    def _local_build_summary(self, session: DiagnosisSession, answers: dict[str, str]) -> str:
        """Compose the local summary: a header sentence, then the description and answers.

        Answers follow the session's question order. Empty parts are skipped and
        the remaining parts are comma-separated so the text stays readable.
        """
        parts = [session.initial_message.strip()]
        parts.extend(answers.get(question.id, "") for question in session.questions)
        details = ",".join(part for part in parts if part)
        return f"{session.category}问题,建议{session.suggested_worker}处理。{details}"

    async def _chat(self, messages: list[dict[str, str]], **options: Any) -> str:
        """POST a chat completion and return the first choice's text content.

        Raises:
            httpx.HTTPError: on transport errors or a non-2xx status.
            ValueError: if the body is not JSON or the content is not a string.
            KeyError, IndexError, TypeError: if the body lacks the expected shape.
        """
        payload: dict[str, Any] = {
            "model": self._model,
            "messages": messages,
            "temperature": 0.1,
            **options,
        }
        response = await self._client.post("/v1/chat/completions", json=payload)
        response.raise_for_status()
        content = response.json()["choices"][0]["message"]["content"]
        if not isinstance(content, str):
            raise ValueError("chat completion returned no text content")
        return content

    async def _call_classify(self, message: str) -> dict[str, Any]:
        """Ask the model to classify *message* and return the sanitized JSON object."""
        content = await self._chat(
            [
                {"role": "system", "content": CLASSIFY_SYSTEM_PROMPT},
                {"role": "user", "content": message},
            ],
            response_format={"type": "json_object"},
        )
        result = json.loads(content)
        if not isinstance(result, dict):
            raise ValueError("classification result is not a JSON object")
        _validate_classify_result(result)
        return result

    async def _call_summary(self, session: DiagnosisSession, answers: list[DiagnosisQuestionAnswer]) -> str:
        """Ask the model for a work-order summary and return the stripped text."""
        # Send the question text instead of bare ids such as "q1", so the model
        # knows what each answer responds to.
        prompts = {question.id: question.prompt for question in session.questions}
        answer_texts = "\n".join(
            f"- {prompts.get(item.question_id, item.question_id)}: {item.answer.strip()}"
            for item in answers
        )
        content = await self._chat(
            [
                {"role": "system", "content": SUMMARY_SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": (
                        f"原始描述:{session.initial_message}\n"
                        f"故障类别:{session.category}\n"
                        f"负责工种:{session.suggested_worker}\n"
                        f"补充回答:\n{answer_texts}"
                    ),
                },
            ]
        )
        return content.strip()
def _validate_classify_result(result: dict[str, Any]) -> None:
valid_categories = {"电路照明", "给排水", "门窗锁具", "空调设备", "网络设备", "家具设施", "其他"}
valid_urgencies = {"", "", "", "紧急"}
category = result.get("category", "")
if category not in valid_categories:
result["category"] = "其他"
urgency = result.get("urgency", "")
if urgency not in valid_urgencies:
result["urgency"] = ""
if not isinstance(result.get("safety_risk"), bool):
result["safety_risk"] = False
worker = result.get("suggested_worker", "")
if not worker:
result["suggested_worker"] = "综合维修"
questions = result.get("questions")
if not isinstance(questions, list):
result["questions"] = []
else:
result["questions"] = [q for q in questions if isinstance(q, dict) and "id" in q and "prompt" in q]
notes = result.get("notes")
if not isinstance(notes, list):
result["notes"] = []
# Lazily created shared provider; close_diagnosis_provider() resets it to None.
_dp: AIDiagnosisProvider | None = None


def get_diagnosis_provider() -> AIDiagnosisProvider:
    """Return the module-wide provider, constructing it on first use."""
    global _dp
    if _dp is not None:
        return _dp
    _dp = AIDiagnosisProvider()
    return _dp
async def close_diagnosis_provider() -> None:
    """Close the shared provider's HTTP client and forget the provider.

    The singleton is cleared before the close is awaited. As a result, a
    get_diagnosis_provider() call made while the client is closing builds a
    fresh provider instead of returning the closing one. It also means a close
    that raises cannot leave the dead provider cached.
    """
    global _dp
    provider, _dp = _dp, None
    if provider is not None:
        await provider._client.aclose()