forked from open-webui/open-webui
		
	Merge pull request #1862 from cheahjs/feat/filter-local-rag-fetch
feat: add ENABLE_LOCAL_WEB_FETCH to protect against SSRF attacks
This commit is contained in:
		
						commit
						1afc49c1e4
					
				
					 4 changed files with 45 additions and 1 deletions
				
			
		|  | @ -31,6 +31,11 @@ from langchain_community.document_loaders import ( | |||
| ) | ||||
| from langchain.text_splitter import RecursiveCharacterTextSplitter | ||||
| 
 | ||||
| import validators | ||||
| import urllib.parse | ||||
| import socket | ||||
| 
 | ||||
| 
 | ||||
| from pydantic import BaseModel | ||||
| from typing import Optional | ||||
| import mimetypes | ||||
|  | @ -84,6 +89,7 @@ from config import ( | |||
|     CHUNK_SIZE, | ||||
|     CHUNK_OVERLAP, | ||||
|     RAG_TEMPLATE, | ||||
|     ENABLE_LOCAL_WEB_FETCH, | ||||
| ) | ||||
| 
 | ||||
| from constants import ERROR_MESSAGES | ||||
|  | @ -454,7 +460,7 @@ def query_collection_handler( | |||
| def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): | ||||
|     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" | ||||
|     try: | ||||
|         loader = WebBaseLoader(form_data.url) | ||||
|         loader = get_web_loader(form_data.url) | ||||
|         data = loader.load() | ||||
| 
 | ||||
|         collection_name = form_data.collection_name | ||||
|  | @ -475,6 +481,37 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): | |||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| def get_web_loader(url: str): | ||||
|     # Check if the URL is valid | ||||
|     if isinstance(validators.url(url), validators.ValidationError): | ||||
|         raise ValueError(ERROR_MESSAGES.INVALID_URL) | ||||
|     if not ENABLE_LOCAL_WEB_FETCH: | ||||
|         # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses | ||||
|         parsed_url = urllib.parse.urlparse(url) | ||||
|         # Get IPv4 and IPv6 addresses | ||||
|         ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname) | ||||
|         # Check if any of the resolved addresses are private | ||||
|         # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader | ||||
|         for ip in ipv4_addresses: | ||||
|             if validators.ipv4(ip, private=True): | ||||
|                 raise ValueError(ERROR_MESSAGES.INVALID_URL) | ||||
|         for ip in ipv6_addresses: | ||||
|             if validators.ipv6(ip, private=True): | ||||
|                 raise ValueError(ERROR_MESSAGES.INVALID_URL) | ||||
|     return WebBaseLoader(url) | ||||
| 
 | ||||
| 
 | ||||
| def resolve_hostname(hostname): | ||||
|     # Get address information | ||||
|     addr_info = socket.getaddrinfo(hostname, None) | ||||
| 
 | ||||
|     # Extract IP addresses from address information | ||||
|     ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET] | ||||
|     ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6] | ||||
| 
 | ||||
|     return ipv4_addresses, ipv6_addresses | ||||
| 
 | ||||
| 
 | ||||
| def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: | ||||
| 
 | ||||
|     text_splitter = RecursiveCharacterTextSplitter( | ||||
|  |  | |||
|  | @ -520,6 +520,8 @@ RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE) | |||
| RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL) | ||||
| RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY) | ||||
| 
 | ||||
| ENABLE_LOCAL_WEB_FETCH = os.getenv("ENABLE_LOCAL_WEB_FETCH", "False").lower() == "true" | ||||
| 
 | ||||
| #################################### | ||||
| # Transcribe | ||||
| #################################### | ||||
|  |  | |||
|  | @ -71,3 +71,7 @@ class ERROR_MESSAGES(str, Enum): | |||
|     EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding." | ||||
| 
 | ||||
|     DB_NOT_SQLITE = "This feature is only available when running with SQLite databases." | ||||
| 
 | ||||
|     INVALID_URL = ( | ||||
|         "Oops! The URL you provided is invalid. Please double-check and try again." | ||||
|     ) | ||||
|  |  | |||
|  | @ -43,6 +43,7 @@ pandas | |||
| openpyxl | ||||
| pyxlsb | ||||
| xlrd | ||||
| validators | ||||
| 
 | ||||
| opencv-python-headless | ||||
| rapidocr-onnxruntime | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy Jaeryang Baek
						Timothy Jaeryang Baek