feat: add ENABLE_LOCAL_WEB_FETCH to protect against SSRF attacks

This commit is contained in:
Jun Siang Cheah 2024-04-29 20:55:17 +01:00
parent e8abaa8bea
commit 1c4e63f71e
4 changed files with 45 additions and 1 deletions

View file

@ -31,6 +31,11 @@ from langchain_community.document_loaders import (
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
import validators
import urllib.parse
import socket
from pydantic import BaseModel
from typing import Optional
import mimetypes
@ -84,6 +89,7 @@ from config import (
CHUNK_SIZE,
CHUNK_OVERLAP,
RAG_TEMPLATE,
ENABLE_LOCAL_WEB_FETCH,
)
from constants import ERROR_MESSAGES
@ -454,7 +460,7 @@ def query_collection_handler(
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try:
loader = WebBaseLoader(form_data.url)
loader = get_web_loader(form_data.url)
data = loader.load()
collection_name = form_data.collection_name
@ -475,6 +481,37 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
)
def get_web_loader(url: str):
# Check if the URL is valid
if isinstance(validators.url(url), validators.ValidationError):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
if not ENABLE_LOCAL_WEB_FETCH:
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
parsed_url = urllib.parse.urlparse(url)
# Get IPv4 and IPv6 addresses
ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
# Check if any of the resolved addresses are private
# This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
for ip in ipv4_addresses:
if validators.ipv4(ip, private=True):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
for ip in ipv6_addresses:
if validators.ipv6(ip, private=True):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
return WebBaseLoader(url)
def resolve_hostname(hostname):
# Get address information
addr_info = socket.getaddrinfo(hostname, None)
# Extract IP addresses from address information
ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
return ipv4_addresses, ipv6_addresses
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
text_splitter = RecursiveCharacterTextSplitter(

View file

@ -520,6 +520,8 @@ RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
ENABLE_LOCAL_WEB_FETCH = os.getenv("ENABLE_LOCAL_WEB_FETCH", "False").lower() == "true"
####################################
# Transcribe
####################################

View file

@ -71,3 +71,7 @@ class ERROR_MESSAGES(str, Enum):
EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
DB_NOT_SQLITE = "This feature is only available when running with SQLite databases."
INVALID_URL = (
"Oops! The URL you provided is invalid. Please double-check and try again."
)

View file

@ -43,6 +43,7 @@ pandas
openpyxl
pyxlsb
xlrd
validators
opencv-python-headless
rapidocr-onnxruntime