From 1c4e63f71eff10b79021b81244d4523e893ce767 Mon Sep 17 00:00:00 2001 From: Jun Siang Cheah Date: Mon, 29 Apr 2024 20:55:17 +0100 Subject: [PATCH] feat: add ENABLE_LOCAL_WEB_FETCH to protect against SSRF attacks --- backend/apps/rag/main.py | 39 ++++++++++++++++++++++++++++++++++++++- backend/config.py | 2 ++ backend/constants.py | 4 ++++ backend/requirements.txt | 1 + 4 files changed, 45 insertions(+), 1 deletion(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index a33a2965..98e11c8a 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -31,6 +31,11 @@ from langchain_community.document_loaders import ( ) from langchain.text_splitter import RecursiveCharacterTextSplitter +import validators +import urllib.parse +import socket + + from pydantic import BaseModel from typing import Optional import mimetypes @@ -84,6 +89,7 @@ from config import ( CHUNK_SIZE, CHUNK_OVERLAP, RAG_TEMPLATE, + ENABLE_LOCAL_WEB_FETCH, ) from constants import ERROR_MESSAGES @@ -454,7 +460,7 @@ def query_collection_handler( def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" try: - loader = WebBaseLoader(form_data.url) + loader = get_web_loader(form_data.url) data = loader.load() collection_name = form_data.collection_name @@ -475,6 +481,37 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): ) +def get_web_loader(url: str): + # Check if the URL is valid + if isinstance(validators.url(url), validators.ValidationError): + raise ValueError(ERROR_MESSAGES.INVALID_URL) + if not ENABLE_LOCAL_WEB_FETCH: + # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses + parsed_url = urllib.parse.urlparse(url) + # Get IPv4 and IPv6 addresses + ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname) + # Check if any of the resolved addresses are private + # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader + for ip in ipv4_addresses: + if validators.ipv4(ip, private=True): + raise ValueError(ERROR_MESSAGES.INVALID_URL) + for ip in ipv6_addresses: + if validators.ipv6(ip, private=True): + raise ValueError(ERROR_MESSAGES.INVALID_URL) + return WebBaseLoader(url) + + +def resolve_hostname(hostname): + # Get address information + addr_info = socket.getaddrinfo(hostname, None) + + # Extract IP addresses from address information + ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET] + ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6] + + return ipv4_addresses, ipv6_addresses + + def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: text_splitter = RecursiveCharacterTextSplitter( diff --git a/backend/config.py b/backend/config.py index f8dbc4d2..09880b12 100644 --- a/backend/config.py +++ b/backend/config.py @@ -520,6 +520,8 @@ RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE) RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL) RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY) +ENABLE_LOCAL_WEB_FETCH = os.getenv("ENABLE_LOCAL_WEB_FETCH", "False").lower() == "true" + #################################### # Transcribe #################################### diff --git a/backend/constants.py b/backend/constants.py index a2694575..3fdf506f 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -71,3 +71,7 @@ class ERROR_MESSAGES(str, Enum): EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding." DB_NOT_SQLITE = "This feature is only available when running with SQLite databases." + + INVALID_URL = ( + "Oops! The URL you provided is invalid. Please double-check and try again." + ) diff --git a/backend/requirements.txt b/backend/requirements.txt index 4ae93ca0..eb509c6e 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -43,6 +43,7 @@ pandas openpyxl pyxlsb xlrd +validators opencv-python-headless rapidocr-onnxruntime