forked from open-webui/open-webui
		
	Merge pull request #1862 from cheahjs/feat/filter-local-rag-fetch
feat: add ENABLE_LOCAL_WEB_FETCH to protect against SSRF attacks
This commit is contained in:
		
						commit
						1afc49c1e4
					
				
					 4 changed files with 45 additions and 1 deletions
				
			
		|  | @ -31,6 +31,11 @@ from langchain_community.document_loaders import ( | |||
| ) | ||||
| from langchain.text_splitter import RecursiveCharacterTextSplitter | ||||
| 
 | ||||
| import validators | ||||
| import urllib.parse | ||||
| import socket | ||||
| 
 | ||||
| 
 | ||||
| from pydantic import BaseModel | ||||
| from typing import Optional | ||||
| import mimetypes | ||||
|  | @ -84,6 +89,7 @@ from config import ( | |||
|     CHUNK_SIZE, | ||||
|     CHUNK_OVERLAP, | ||||
|     RAG_TEMPLATE, | ||||
|     ENABLE_LOCAL_WEB_FETCH, | ||||
| ) | ||||
| 
 | ||||
| from constants import ERROR_MESSAGES | ||||
|  | @ -454,7 +460,7 @@ def query_collection_handler( | |||
| def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): | ||||
|     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" | ||||
|     try: | ||||
|         loader = WebBaseLoader(form_data.url) | ||||
|         loader = get_web_loader(form_data.url) | ||||
|         data = loader.load() | ||||
| 
 | ||||
|         collection_name = form_data.collection_name | ||||
|  | @ -475,6 +481,37 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): | |||
|         ) | ||||
| 
 | ||||
| 
 | ||||
def get_web_loader(url: str):
    """Validate *url* and return a ``WebBaseLoader`` for it.

    Raises:
        ValueError: if the URL is malformed, or if local web fetching is
            disabled and the hostname resolves to a private IP address
            (SSRF mitigation).
    """
    # validators.url() returns a falsy ValidationFailure/ValidationError
    # object on failure. Truthiness is the portable check: the isinstance
    # test against validators.ValidationError raises AttributeError on
    # validators versions (< 0.20) that still name the class
    # ValidationFailure, and silently passes bad URLs otherwise.
    if not validators.url(url):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)
    if not ENABLE_LOCAL_WEB_FETCH:
        # Local web fetch is disabled: reject any URL whose hostname
        # resolves to a private address.
        parsed_url = urllib.parse.urlparse(url)
        if parsed_url.hostname is None:
            # No hostname component (e.g. a scheme-only URL); nothing to
            # resolve, so reject rather than crash in getaddrinfo.
            raise ValueError(ERROR_MESSAGES.INVALID_URL)
        # Resolve both address families up front.
        ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
        # NOTE: this is still vulnerable to DNS rebinding, as we don't
        # control WebBaseLoader — it performs its own resolution at
        # fetch time.
        for ip in ipv4_addresses:
            if validators.ipv4(ip, private=True):
                raise ValueError(ERROR_MESSAGES.INVALID_URL)
        for ip in ipv6_addresses:
            if validators.ipv6(ip, private=True):
                raise ValueError(ERROR_MESSAGES.INVALID_URL)
    return WebBaseLoader(url)
| 
 | ||||
| 
 | ||||
def resolve_hostname(hostname):
    """Resolve *hostname* to its IPv4 and IPv6 addresses.

    Returns:
        tuple[list[str], list[str]]: ``(ipv4_addresses, ipv6_addresses)``,
        each deduplicated while preserving resolution order.

    Raises:
        socket.gaierror: if the hostname cannot be resolved.
    """
    # Get address information for all families and socket types.
    addr_info = socket.getaddrinfo(hostname, None)

    # getaddrinfo yields one entry per socket type (stream/dgram/raw),
    # so the same IP usually appears several times; dict.fromkeys
    # deduplicates while keeping first-seen order.
    ipv4_addresses = list(
        dict.fromkeys(
            info[4][0] for info in addr_info if info[0] == socket.AF_INET
        )
    )
    ipv6_addresses = list(
        dict.fromkeys(
            info[4][0] for info in addr_info if info[0] == socket.AF_INET6
        )
    )

    return ipv4_addresses, ipv6_addresses
| 
 | ||||
| 
 | ||||
| def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: | ||||
| 
 | ||||
|     text_splitter = RecursiveCharacterTextSplitter( | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy Jaeryang Baek
						Timothy Jaeryang Baek