forked from open-webui/open-webui
Merge pull request #1862 from cheahjs/feat/filter-local-rag-fetch
feat: add ENABLE_LOCAL_WEB_FETCH to protect against SSRF attacks
This commit is contained in:
commit
1afc49c1e4
4 changed files with 45 additions and 1 deletions
|
@ -31,6 +31,11 @@ from langchain_community.document_loaders import (
|
||||||
)
|
)
|
||||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||||
|
|
||||||
|
import validators
|
||||||
|
import urllib.parse
|
||||||
|
import socket
|
||||||
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import mimetypes
|
import mimetypes
|
||||||
|
@ -84,6 +89,7 @@ from config import (
|
||||||
CHUNK_SIZE,
|
CHUNK_SIZE,
|
||||||
CHUNK_OVERLAP,
|
CHUNK_OVERLAP,
|
||||||
RAG_TEMPLATE,
|
RAG_TEMPLATE,
|
||||||
|
ENABLE_LOCAL_WEB_FETCH,
|
||||||
)
|
)
|
||||||
|
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
@ -454,7 +460,7 @@ def query_collection_handler(
|
||||||
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
|
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
|
||||||
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
|
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
|
||||||
try:
|
try:
|
||||||
loader = WebBaseLoader(form_data.url)
|
loader = get_web_loader(form_data.url)
|
||||||
data = loader.load()
|
data = loader.load()
|
||||||
|
|
||||||
collection_name = form_data.collection_name
|
collection_name = form_data.collection_name
|
||||||
|
@ -475,6 +481,37 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_web_loader(url: str):
    """Validate ``url`` and return a ``WebBaseLoader`` for it.

    Acts as the SSRF guard for RAG web fetches: the URL must be
    syntactically valid, and when ``ENABLE_LOCAL_WEB_FETCH`` is off it
    must not resolve to a private IP address.

    Raises:
        ValueError(ERROR_MESSAGES.INVALID_URL): if the URL is malformed,
            has no hostname, or resolves to a private address while
            local web fetch is disabled.
    """
    # validators.url() returns a ValidationError *instance* on failure
    # rather than raising, so check the return value's type.
    if isinstance(validators.url(url), validators.ValidationError):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)

    if not ENABLE_LOCAL_WEB_FETCH:
        # Local web fetch is disabled: filter out URLs that resolve to
        # private IP addresses.
        # NOTE: this is still vulnerable to DNS rebinding, since the
        # actual fetch by WebBaseLoader re-resolves the hostname later.
        hostname = urllib.parse.urlparse(url).hostname
        if not hostname:
            # urlparse can yield hostname=None for odd-but-parseable URLs;
            # getaddrinfo(None, ...) would fall back to the *local* host,
            # defeating the check — treat it as an invalid URL instead.
            raise ValueError(ERROR_MESSAGES.INVALID_URL)

        # Resolve both IPv4 and IPv6 addresses and reject any private one.
        ipv4_addresses, ipv6_addresses = resolve_hostname(hostname)
        for ip in ipv4_addresses:
            if validators.ipv4(ip, private=True):
                raise ValueError(ERROR_MESSAGES.INVALID_URL)
        for ip in ipv6_addresses:
            if validators.ipv6(ip, private=True):
                raise ValueError(ERROR_MESSAGES.INVALID_URL)

    return WebBaseLoader(url)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_hostname(hostname):
    """Resolve ``hostname`` into its IPv4 and IPv6 addresses.

    Returns:
        A ``(ipv4_addresses, ipv6_addresses)`` tuple of string IPs.
        Either list may contain duplicates — one entry per
        ``getaddrinfo`` record of the matching address family.

    Raises:
        socket.gaierror: if the hostname cannot be resolved.
    """
    v4, v6 = [], []
    # Single pass over the resolver records, bucketing by address family.
    for family, _socktype, _proto, _canonname, sockaddr in socket.getaddrinfo(
        hostname, None
    ):
        if family == socket.AF_INET:
            v4.append(sockaddr[0])
        elif family == socket.AF_INET6:
            v6.append(sockaddr[0])
    return v4, v6
|
||||||
|
|
||||||
|
|
||||||
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
|
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
|
||||||
|
|
||||||
text_splitter = RecursiveCharacterTextSplitter(
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
|
|
|
@ -520,6 +520,8 @@ RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
|
||||||
RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
|
RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
|
||||||
RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
|
RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
|
||||||
|
|
||||||
|
ENABLE_LOCAL_WEB_FETCH = os.getenv("ENABLE_LOCAL_WEB_FETCH", "False").lower() == "true"
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# Transcribe
|
# Transcribe
|
||||||
####################################
|
####################################
|
||||||
|
|
|
@ -71,3 +71,7 @@ class ERROR_MESSAGES(str, Enum):
|
||||||
EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
|
EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
|
||||||
|
|
||||||
DB_NOT_SQLITE = "This feature is only available when running with SQLite databases."
|
DB_NOT_SQLITE = "This feature is only available when running with SQLite databases."
|
||||||
|
|
||||||
|
INVALID_URL = (
|
||||||
|
"Oops! The URL you provided is invalid. Please double-check and try again."
|
||||||
|
)
|
||||||
|
|
|
@ -43,6 +43,7 @@ pandas
|
||||||
openpyxl
|
openpyxl
|
||||||
pyxlsb
|
pyxlsb
|
||||||
xlrd
|
xlrd
|
||||||
|
validators
|
||||||
|
|
||||||
opencv-python-headless
|
opencv-python-headless
|
||||||
rapidocr-onnxruntime
|
rapidocr-onnxruntime
|
||||||
|
|
Loading…
Reference in a new issue