1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
| import os import sys import glob import os import warnings
from langchain_core._api import LangChainDeprecationWarning warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
from langchain.agents import initialize_agent, Tool from langchain.chains import RetrievalQA
from langchain_ollama import ChatOllama from langchain_chroma import Chroma from langchain_community.embeddings import HuggingFaceEmbeddings from langchain_community.document_loaders import TextLoader from langchain.text_splitter import CharacterTextSplitter from langchain_community.tools import DuckDuckGoSearchRun
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
from tools.medical_tools import MedicalTools
class MedicalAgent:
    """Medical question-answering agent built on a local Ollama chat model.

    The agent combines four tools: a retrieval-QA chain over the local
    text knowledge base, a DuckDuckGo web search, and two helpers
    (symptom extraction and severity assessment) that delegate to
    ``MedicalTools``.
    """

    # Vector size handed to FakeEmbeddings for every store built by this class.
    _EMBEDDING_SIZE = 768

    def __init__(self):
        """Create the LLM, the knowledge-base chain, the tools and the agent."""
        # Created first: the symptom/severity tools registered below call
        # into this object, so it must exist before any tool can be invoked.
        self.medical_tools = MedicalTools()

        self.llm = ChatOllama(
            model="gemma3:4b",
            temperature=0,
            base_url="http://localhost:11434",
            timeout=300,
        )

        self.vectorstore = self.load_medical_knowledge()
        self.retrieval_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.vectorstore.as_retriever(
                search_kwargs={"k": 2}
            ),
            return_source_documents=True,
        )

        self.search = DuckDuckGoSearchRun()
        self.tools = self._build_tools()

        self.agent = initialize_agent(
            self.tools,
            self.llm,
            agent="zero-shot-react-description",
            verbose=False,
            handle_parsing_errors=True,
            max_iterations=3,
            early_stopping_method="force",
        )

    def _build_tools(self):
        """Return the list of ``Tool`` objects exposed to the agent."""
        return [
            Tool(
                name="Medical Knowledge Base",
                func=self.query_medical_knowledge,
                description="适合用来回答医学知识相关的问题,包括疾病、药物、急救和健康生活方式等内容",
            ),
            Tool(
                name="Web Search",
                func=self.search.run,
                description="适合用来搜索最新的医疗信息、研究进展和新闻等互联网信息",
            ),
            Tool(
                name="Symptom Extractor",
                func=self.extract_symptoms,
                description="适合用来从文本中提取症状信息",
            ),
            Tool(
                name="Severity Assessment",
                func=self.assess_severity,
                description="适合用来评估症状的严重程度",
            ),
        ]

    def _fake_embeddings(self):
        """Return the embedding object used for every vector store here.

        NOTE(review): FakeEmbeddings is a placeholder embedding, so
        similarity search over it is presumably not semantically
        meaningful. ``HuggingFaceEmbeddings`` is already imported by this
        file and looks like the intended replacement -- confirm.
        """
        from langchain_community.embeddings import FakeEmbeddings

        return FakeEmbeddings(size=self._EMBEDDING_SIZE)

    def load_medical_knowledge(self):
        """Load the medical knowledge base into a Chroma vector store.

        Every ``*.txt`` file in the ``../../data`` directory (relative to
        this module) is read as UTF-8, split into 1000-character chunks
        with a 200-character overlap, and indexed. When no file is found,
        a warning is printed and a store holding a single placeholder
        document is returned instead.

        Returns:
            The ``Chroma`` vector store built from the documents.
        """
        data_dir = os.path.join(
            os.path.abspath(os.path.dirname(__file__)), "../../data"
        )
        # Sorted so the indexing order does not depend on directory listing order.
        files = sorted(glob.glob(os.path.join(data_dir, "*.txt")))

        if not files:
            print("警告: 未找到医疗知识库文件。将创建一个空的向量存储。")
            from langchain_core.documents import Document

            empty_docs = [
                Document(
                    page_content="这是一个空的医疗知识库文档",
                    metadata={"source": "empty"},
                )
            ]
            return Chroma.from_documents(empty_docs, self._fake_embeddings())

        documents = []
        for path in files:
            documents.extend(TextLoader(path, encoding="utf-8").load())

        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        texts = text_splitter.split_documents(documents)

        print("使用FakeEmbeddings创建向量存储...")
        return Chroma.from_documents(texts, self._fake_embeddings())

    def query_medical_knowledge(self, query):
        """Answer ``query`` from the knowledge base and list its sources.

        Args:
            query: The question text passed to the retrieval chain.

        Returns:
            The chain's answer, followed by a ``来源:`` line naming the
            source files (base names, de-duplicated, alphabetically
            sorted) when any retrieved document carries a ``source``
            metadata entry. If the chain raises, the error is printed and
            a failure message is returned instead of propagating.
        """
        # Deliberately broad: this runs as an agent tool, and the failure is
        # reported back as the tool's text output rather than raised.
        try:
            result = self.retrieval_chain.invoke(query)
            answer = result["result"]

            sources = {
                os.path.basename(doc.metadata["source"])
                for doc in result["source_documents"]
                if "source" in doc.metadata
            }
            # Sorted: set iteration order would make the output vary per run.
            formatted_sources = (
                "\n来源: " + ", ".join(sorted(sources)) if sources else ""
            )
            return f"{answer}{formatted_sources}"
        except Exception as e:
            print(f"知识库查询错误: {e}")
            return f"知识库查询失败: {str(e)}"

    def extract_symptoms(self, text):
        """Extract symptoms from ``text`` and format them as one line.

        Args:
            text: Free text to scan for symptoms.

        Returns:
            A comma-separated listing of the symptoms reported by
            ``MedicalTools.extract_symptoms``, or a fixed message when
            none were found.
        """
        symptoms = self.medical_tools.extract_symptoms(text)
        if not symptoms:
            return "未提取到明显症状"
        return f"提取到的症状: {', '.join(symptoms)}"

    def assess_severity(self, symptoms_text):
        """Assess the severity of the symptoms mentioned in ``symptoms_text``.

        Args:
            symptoms_text: Free text describing symptoms.

        Returns:
            The extracted symptoms followed, on a new line, by the result
            of ``MedicalTools.assess_severity``; or a fixed message when no
            symptom could be extracted.
        """
        symptoms = self.medical_tools.extract_symptoms(symptoms_text)
        if not symptoms:
            return "未提取到可评估的症状"

        severity = self.medical_tools.assess_severity(symptoms)
        return f"症状: {', '.join(symptoms)}\n{severity}"

    def run(self, question):
        """Validate ``question`` and let the agent answer it.

        Args:
            question: The user's question.

        Returns:
            The validation error message when
            ``MedicalTools.validate_medical_query`` rejects the question;
            otherwise whatever ``self.agent.invoke`` returns, unmodified.
            NOTE(review): the agent result is presumably a dict with an
            ``"output"`` key rather than a plain string -- callers that
            want text should unwrap it.
        """
        is_valid, error_msg = self.medical_tools.validate_medical_query(question)
        if not is_valid:
            return error_msg

        return self.agent.invoke(question)
if __name__ == "__main__":
    # Demo run: build the agent once and ask it a fixed set of sample questions.
    medical_agent = MedicalAgent()

    questions = [
        "什么是高血压?如何预防?",
        "老年高血压患者有哪些注意事项?",
        "脑血栓的高危因素有哪些?如何预防?",
        "糖尿病的预防措施有哪些?",
        "老年人如何保持健康的生活方式?",
        "高血压、糖尿病和脑血栓之间有什么关系?",
        "高危人群应该多久进行一次体检?",
    ]

    for q in questions:
        print(f"\n问题: {q}")
        result = medical_agent.run(q)
        # run() returns a plain message when validation fails, but hands back
        # the agent's raw result otherwise. When that result is a mapping with
        # an "output" entry, print only the answer text instead of the whole
        # mapping.
        if isinstance(result, dict) and "output" in result:
            result = result["output"]
        print(f"回答: {result}")
|