forked from open-webui/open-webui
Merge pull request #1862 from cheahjs/feat/filter-local-rag-fetch
feat: add ENABLE_LOCAL_WEB_FETCH to protect against SSRF attacks
commit 1afc49c1e4

4 changed files with 45 additions and 1 deletion
@@ -31,6 +31,11 @@ from langchain_community.document_loaders import (
 )
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 
+import validators
+import urllib.parse
+import socket
+
+
 from pydantic import BaseModel
 from typing import Optional
 import mimetypes
@@ -84,6 +89,7 @@ from config import (
     CHUNK_SIZE,
     CHUNK_OVERLAP,
     RAG_TEMPLATE,
+    ENABLE_LOCAL_WEB_FETCH,
 )
 
 from constants import ERROR_MESSAGES
@@ -454,7 +460,7 @@ def query_collection_handler(
 def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
     try:
-        loader = WebBaseLoader(form_data.url)
+        loader = get_web_loader(form_data.url)
         data = loader.load()
 
         collection_name = form_data.collection_name
@@ -475,6 +481,37 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
         )
 
 
+def get_web_loader(url: str):
+    # Check if the URL is valid
+    if isinstance(validators.url(url), validators.ValidationError):
+        raise ValueError(ERROR_MESSAGES.INVALID_URL)
+    if not ENABLE_LOCAL_WEB_FETCH:
+        # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
+        parsed_url = urllib.parse.urlparse(url)
+        # Get IPv4 and IPv6 addresses
+        ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
+        # Check if any of the resolved addresses are private
+        # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
+        for ip in ipv4_addresses:
+            if validators.ipv4(ip, private=True):
+                raise ValueError(ERROR_MESSAGES.INVALID_URL)
+        for ip in ipv6_addresses:
+            if validators.ipv6(ip, private=True):
+                raise ValueError(ERROR_MESSAGES.INVALID_URL)
+    return WebBaseLoader(url)
+
+
+def resolve_hostname(hostname):
+    # Get address information
+    addr_info = socket.getaddrinfo(hostname, None)
+
+    # Extract IP addresses from address information
+    ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
+    ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
+
+    return ipv4_addresses, ipv6_addresses
+
+
 def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
 
     text_splitter = RecursiveCharacterTextSplitter(
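The new guard resolves the hostname up front and rejects URLs whose A or AAAA records land in private ranges; as the inline comment notes, it does not cover DNS rebinding, since WebBaseLoader resolves the hostname again when it actually fetches. For reference, a minimal standard-library-only sketch of the same check (the helper name is_private_url is hypothetical; the patch itself relies on validators.ipv4/validators.ipv6 with private=True):

import ipaddress
import socket
import urllib.parse


def is_private_url(url: str) -> bool:
    # Return True if the URL's hostname resolves to any private/loopback/link-local address.
    hostname = urllib.parse.urlparse(url).hostname
    if hostname is None:
        return True  # unparsable URLs are treated as unsafe
    for info in socket.getaddrinfo(hostname, None):
        ip = ipaddress.ip_address(info[4][0].split("%")[0])  # drop any IPv6 zone id
        if ip.is_private or ip.is_loopback or ip.is_link_local:
            return True
    return False


print(is_private_url("http://127.0.0.1/admin"))      # True  (loopback)
print(is_private_url("http://192.168.1.10/"))        # True  (RFC 1918)
print(is_private_url("https://www.gutenberg.org/"))  # False for a public host, assuming DNS resolution succeeds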
@@ -520,6 +520,8 @@ RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
 RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
 RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
 
+ENABLE_LOCAL_WEB_FETCH = os.getenv("ENABLE_LOCAL_WEB_FETCH", "False").lower() == "true"
+
 ####################################
 # Transcribe
 ####################################
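Worth noting: the flag is opt-in and defaults to off; only the literal string "true" (case-insensitive) enables local web fetch. A throwaway sketch of the same parsing, assuming nothing beyond the config.py line above:

import os

for value in (None, "False", "0", "yes", "true", "TRUE"):
    if value is None:
        os.environ.pop("ENABLE_LOCAL_WEB_FETCH", None)  # unset -> falls back to the "False" default
    else:
        os.environ["ENABLE_LOCAL_WEB_FETCH"] = value
    enabled = os.getenv("ENABLE_LOCAL_WEB_FETCH", "False").lower() == "true"
    print(f"{value!r:>8} -> {enabled}")
# Only "true" and "TRUE" print True; everything else, including unset, prints False.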
@@ -71,3 +71,7 @@ class ERROR_MESSAGES(str, Enum):
     EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
 
     DB_NOT_SQLITE = "This feature is only available when running with SQLite databases."
+
+    INVALID_URL = (
+        "Oops! The URL you provided is invalid. Please double-check and try again."
+    )
@@ -43,6 +43,7 @@ pandas
 openpyxl
 pyxlsb
 xlrd
+validators
 
 opencv-python-headless
 rapidocr-onnxruntime