1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229
| import streamlit as st import os from pathlib import Path from components.document_loader import DocumentLoader from components.vector_store import VectorStoreFactory from components.retriever import AdvancedRetriever from components.generator import RAGGenerator from config.settings import settings
# Browser-page configuration: tab title and icon, wide layout, and the
# sidebar expanded when the page first loads.
_page_options = {
    "page_title": "智能文档问答系统",
    "page_icon": "🤖",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_page_options)
# Seed the per-session state the rest of the script reads. Keys that already
# exist (i.e. on every rerun after the first) are left untouched.
_session_defaults = {
    "vector_store": None,   # set once documents have been indexed
    "rag_generator": None,  # set together with the vector store
    "chat_history": [],     # list of (question, answer, sources) tuples
}
for _state_key, _initial_value in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
with st.sidebar:
    # --- Document upload and indexing ------------------------------------
    st.header("📁 文档管理")

    uploaded_files = st.file_uploader(
        "上传文档",
        type=["txt", "pdf", "docx", "csv"],
        accept_multiple_files=True,
        help="支持TXT、PDF、DOCX、CSV格式",
    )

    if uploaded_files and st.button("处理文档", type="primary"):
        with st.spinner("正在处理文档..."):
            # Uploads are written to a scratch directory because the loader
            # is handed file paths rather than in-memory buffers.
            docs_dir = Path("./temp_docs")
            docs_dir.mkdir(exist_ok=True)

            saved_files = []
            try:
                for uploaded_file in uploaded_files:
                    # Keep only the final path component of the
                    # client-supplied name so the file cannot be written
                    # outside docs_dir.
                    file_path = docs_dir / Path(uploaded_file.name).name
                    # Register the path before writing so that a partially
                    # written file is still removed by the cleanup below.
                    saved_files.append(file_path)
                    with open(file_path, "wb") as f:
                        f.write(uploaded_file.getbuffer())

                loader = DocumentLoader(
                    chunk_size=settings.chunk_size,
                    chunk_overlap=settings.chunk_overlap,
                )

                all_documents = []
                for file_path in saved_files:
                    # One bad file must not abort the whole batch: report it
                    # and continue with the remaining files.
                    try:
                        documents = loader.load_file(file_path)
                        all_documents.extend(documents)
                    except Exception as e:
                        st.error(f"处理文件 {file_path.name} 时出错: {e}")

                if all_documents:
                    vector_store = VectorStoreFactory.create_vector_store(
                        store_type=settings.vector_db_type,
                        collection_name="uploaded_docs",
                    )
                    vector_store.add_documents(all_documents)

                    retriever = AdvancedRetriever(vector_store)
                    rag_generator = RAGGenerator(retriever)

                    # Persist across reruns; the main page switches to the
                    # Q&A view once rag_generator is set.
                    st.session_state.vector_store = vector_store
                    st.session_state.rag_generator = rag_generator

                    st.success(f"成功处理 {len(all_documents)} 个文档片段!")
            finally:
                # Always remove the scratch copies, including when saving,
                # loading or indexing raised, so temp_docs does not
                # accumulate stale uploads.
                for file_path in saved_files:
                    file_path.unlink(missing_ok=True)

    # --- Runtime settings --------------------------------------------------
    st.header("⚙️ 系统配置")

    # NOTE(review): in the visible source only top_k is read afterwards (for
    # retrieval in streaming mode); temperature and max_tokens are not passed
    # to the generator -- confirm whether they should be.
    top_k = st.slider("检索文档数量", 1, 10, settings.top_k)
    temperature = st.slider("生成温度", 0.0, 1.0, settings.temperature, 0.1)
    max_tokens = st.number_input("最大Token数", 100, 2000, settings.max_tokens)

    if st.button("清除对话历史"):
        st.session_state.chat_history = []
        st.rerun()
# Main-page heading followed by a one-line description of the app.
_heading_text = "🤖 智能文档问答系统"
_tagline_text = "基于RAG技术的智能文档问答,支持多种文档格式"
st.title(_heading_text)
st.markdown(_tagline_text)
# Main page: show onboarding help until documents have been indexed, then the
# conversation history, the question form and summary statistics.
if st.session_state.rag_generator is None:
    st.info("👈 请先在侧边栏上传文档")

    with st.expander("💡 使用说明", expanded=True):
        st.markdown("""
        ### 如何使用:
        1. **上传文档**:在左侧侧边栏上传您的文档文件
        2. **处理文档**:点击"处理文档"按钮,系统会自动分析和索引文档内容
        3. **开始问答**:在下方输入框中输入您的问题
        4. **查看答案**:系统会基于文档内容生成准确的答案

        ### 支持的文档格式:
        - 📄 TXT文本文件
        - 📕 PDF文档
        - 📘 Word文档(DOCX)
        - 📊 CSV数据文件

        ### 功能特点:
        - 🔍 智能检索:基于语义相似度检索相关内容
        - 🧠 上下文理解:结合多个文档片段生成综合答案
        - 📚 来源追踪:显示答案的具体来源文档
        - 💬 对话记忆:支持多轮对话上下文
        """)
else:
    st.header("💬 智能问答")

    # --- Conversation so far ---------------------------------------------
    # Loop variables are named so they do not shadow the `question` input
    # widget value defined further down.
    for turn_no, (past_question, past_answer, past_sources) in enumerate(
        st.session_state.chat_history, start=1
    ):
        with st.container():
            st.markdown(f"**🙋 问题 {turn_no}:** {past_question}")
            st.markdown(f"**🤖 回答:** {past_answer}")

            if past_sources:
                with st.expander(f"📚 参考来源 ({len(past_sources)}个文档片段)"):
                    for source_no, source in enumerate(past_sources, start=1):
                        st.markdown(f"**片段 {source_no}:**")
                        # Show at most the first 300 characters of a chunk.
                        content = source.page_content
                        st.text(
                            (content[:300] + "...") if len(content) > 300 else content
                        )
                        if hasattr(source, 'metadata') and source.metadata:
                            st.caption(f"来源:{source.metadata}")

            st.markdown("---")

    # --- Question form ---------------------------------------------------
    question = st.text_input(
        "请输入您的问题:",
        placeholder="例如:文档中提到了哪些关键技术?",
        key="question_input",
    )

    col1, col2 = st.columns([1, 4])
    with col1:
        ask_button = st.button("🚀 提问", type="primary", use_container_width=True)
    with col2:
        stream_mode = st.checkbox("流式输出", value=False)

    if ask_button and question:
        with st.spinner("正在思考中..."):
            try:
                if stream_mode:
                    # Render the answer incrementally into one placeholder.
                    answer_placeholder = st.empty()
                    answer_text = ""

                    for chunk in st.session_state.rag_generator.generate_streaming_answer(question):
                        answer_text += chunk
                        answer_placeholder.markdown(f"**🤖 回答:** {answer_text}")

                    # The streaming call yields only text, so the source
                    # documents are retrieved separately for the history.
                    docs = st.session_state.rag_generator.retriever.retrieve(question, k=top_k)
                    result = {
                        "answer": answer_text,
                        "source_documents": docs,
                        "question": question,
                    }
                else:
                    result = st.session_state.rag_generator.generate_answer(question)

                st.session_state.chat_history.append((
                    result["question"],
                    result["answer"],
                    result["source_documents"],
                ))
            except Exception as e:
                st.error(f"生成答案时出错:{e}")
            else:
                # Rerun outside the try body so the broad handler above only
                # ever sees failures from retrieval/generation. The rerun
                # redraws the page with the new turn in the history.
                st.rerun()

    # --- Summary statistics ----------------------------------------------
    if st.session_state.chat_history:
        history = st.session_state.chat_history
        st.header("📊 对话统计")

        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("对话轮数", len(history))
        with col2:
            total_chars = sum(len(answer) for _, answer, _ in history)
            st.metric("总回答字数", total_chars)
        with col3:
            # `sources or ()` keeps the average consistent with the history
            # view above, which already tolerates a missing/empty source list.
            avg_sources = sum(len(sources or ()) for _, _, sources in history) / len(history)
            st.metric("平均参考来源", f"{avg_sources:.1f}")
# Page footer: a horizontal rule, then a centered gray credit line that is
# rendered as raw HTML (hence unsafe_allow_html).
_footer_html = (
    "<div style='text-align: center; color: gray;'>"
    "🤖 智能文档问答系统 | 基于LangChain和Streamlit构建"
    "</div>"
)
st.markdown("---")
st.markdown(_footer_html, unsafe_allow_html=True)
|