feat: add ENABLE_LOCAL_WEB_FETCH to protect against SSRF attacks

This commit is contained in:
Jun Siang Cheah 2024-04-29 20:55:17 +01:00
parent e8abaa8bea
commit 1c4e63f71e
4 changed files with 45 additions and 1 deletions

View file

@ -31,6 +31,11 @@ from langchain_community.document_loaders import (
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
import validators
import urllib.parse
import socket
from pydantic import BaseModel
from typing import Optional
import mimetypes
@ -84,6 +89,7 @@ from config import (
CHUNK_SIZE,
CHUNK_OVERLAP,
RAG_TEMPLATE,
ENABLE_LOCAL_WEB_FETCH,
)
from constants import ERROR_MESSAGES
@ -454,7 +460,7 @@ def query_collection_handler(
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try:
loader = WebBaseLoader(form_data.url)
loader = get_web_loader(form_data.url)
data = loader.load()
collection_name = form_data.collection_name
@ -475,6 +481,37 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
)
def get_web_loader(url: str):
# Check if the URL is valid
if isinstance(validators.url(url), validators.ValidationError):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
if not ENABLE_LOCAL_WEB_FETCH:
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
parsed_url = urllib.parse.urlparse(url)
# Get IPv4 and IPv6 addresses
ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
# Check if any of the resolved addresses are private
# This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
for ip in ipv4_addresses:
if validators.ipv4(ip, private=True):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
for ip in ipv6_addresses:
if validators.ipv6(ip, private=True):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
return WebBaseLoader(url)
def resolve_hostname(hostname):
# Get address information
addr_info = socket.getaddrinfo(hostname, None)
# Extract IP addresses from address information
ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
return ipv4_addresses, ipv6_addresses
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
text_splitter = RecursiveCharacterTextSplitter(