forked from open-webui/open-webui
Compare commits
3 commits
Author | SHA1 | Date | |
---|---|---|---|
|
59de980306 | ||
|
837feb4e79 | ||
|
6caa7750bb |
4 changed files with 321 additions and 1 deletions
172
backend/apps/functions/main.py
Normal file
172
backend/apps/functions/main.py
Normal file
|
@ -0,0 +1,172 @@
|
|||
from pathlib import Path
|
||||
import ast
|
||||
import builtins
|
||||
|
||||
|
||||
from fastapi import (
|
||||
FastAPI,
|
||||
Request,
|
||||
Depends,
|
||||
HTTPException,
|
||||
status,
|
||||
UploadFile,
|
||||
File,
|
||||
Form,
|
||||
)
|
||||
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from apps.functions.security import ALLOWED_MODULES, ALLOWED_BUILTINS, custom_import
|
||||
from utils.utils import get_current_user, get_admin_user
|
||||
|
||||
|
||||
from config import FUNCTIONS_DIR
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from typing import Optional
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def get_status():
|
||||
return {"status": True}
|
||||
|
||||
|
||||
class FunctionForm(BaseModel):
|
||||
name: str
|
||||
content: str
|
||||
|
||||
|
||||
@app.post("/add")
|
||||
def add_function(
|
||||
form_data: FunctionForm,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
try:
|
||||
filename = f"{FUNCTIONS_DIR}/{form_data.name}.py"
|
||||
if not Path(filename).exists():
|
||||
with open(filename, "w") as file:
|
||||
file.write(form_data.content)
|
||||
return f"{form_data.name}.py" in list(
|
||||
map(lambda x: x.name, Path(FUNCTIONS_DIR).rglob("./*"))
|
||||
)
|
||||
else:
|
||||
raise Exception("Function already exists")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return False
|
||||
|
||||
|
||||
@app.post("/update")
|
||||
def update_function(
|
||||
form_data: FunctionForm,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
try:
|
||||
filename = f"{FUNCTIONS_DIR}/{form_data.name}.py"
|
||||
if Path(filename).exists():
|
||||
with open(filename, "w") as file:
|
||||
file.write(form_data.content)
|
||||
return f"{form_data.name}.py" in list(
|
||||
map(lambda x: x.name, Path(FUNCTIONS_DIR).rglob("./*"))
|
||||
)
|
||||
else:
|
||||
raise Exception("Function does not exist")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return False
|
||||
|
||||
|
||||
@app.get("/check/{function}")
|
||||
def check_function(
|
||||
function: str,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
filename = f"{FUNCTIONS_DIR}/{function}.py"
|
||||
|
||||
# Check if the function file exists
|
||||
if not Path(filename).is_file():
|
||||
raise HTTPException(status_code=404, detail="Function not found")
|
||||
|
||||
# Read the code from the file
|
||||
with open(filename, "r") as file:
|
||||
code = file.read()
|
||||
|
||||
return {"name": function, "content": code}
|
||||
|
||||
|
||||
@app.get("/list")
|
||||
def list_functions(
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
files = list(map(lambda x: x.name, Path(FUNCTIONS_DIR).rglob("./*")))
|
||||
return files
|
||||
|
||||
|
||||
def validate_imports(code):
|
||||
try:
|
||||
tree = ast.parse(code)
|
||||
except SyntaxError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Syntax error in function: {e}")
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Import):
|
||||
module_names = [alias.name for alias in node.names]
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
module_names = [node.module]
|
||||
else:
|
||||
continue
|
||||
|
||||
for name in module_names:
|
||||
if name not in ALLOWED_MODULES:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Import of module {name} is not allowed"
|
||||
)
|
||||
|
||||
|
||||
@app.post("/exec/{function}")
|
||||
def exec_function(
|
||||
function: str,
|
||||
kwargs: Optional[dict] = None,
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
filename = f"{FUNCTIONS_DIR}/{function}.py"
|
||||
|
||||
# Check if the function file exists
|
||||
if not Path(filename).is_file():
|
||||
raise HTTPException(status_code=404, detail="Function not found")
|
||||
|
||||
# Read the code from the file
|
||||
with open(filename, "r") as file:
|
||||
code = file.read()
|
||||
|
||||
validate_imports(code)
|
||||
|
||||
try:
|
||||
# Execute the code within a restricted namespace
|
||||
namespace = {name: getattr(builtins, name) for name in ALLOWED_BUILTINS}
|
||||
namespace["__import__"] = custom_import
|
||||
exec(code, namespace)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Function: {e}")
|
||||
|
||||
# Check if the function exists in the namespace
|
||||
if "main" not in namespace or not callable(namespace["main"]):
|
||||
raise HTTPException(status_code=400, detail="Invalid function")
|
||||
|
||||
try:
|
||||
# Execute the function with provided kwargs
|
||||
result = namespace["main"](kwargs) if kwargs else namespace["main"]()
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Function: {e}")
|
140
backend/apps/functions/security.py
Normal file
140
backend/apps/functions/security.py
Normal file
|
@ -0,0 +1,140 @@
|
|||
ALLOWED_MODULES = {
|
||||
"pydantic",
|
||||
"math",
|
||||
"json",
|
||||
"time",
|
||||
"datetime",
|
||||
"requests",
|
||||
} # Add allowed modules here
|
||||
|
||||
|
||||
def custom_import(name, globals=None, locals=None, fromlist=(), level=0):
|
||||
if name in ALLOWED_MODULES:
|
||||
return __import__(name, globals, locals, fromlist, level)
|
||||
raise ImportError(f"Import of module {name} is not allowed")
|
||||
|
||||
|
||||
# Define a restricted set of builtins
|
||||
ALLOWED_BUILTINS = {
|
||||
"ArithmeticError",
|
||||
"AssertionError",
|
||||
"AttributeError",
|
||||
"BaseException",
|
||||
"BufferError",
|
||||
"BytesWarning",
|
||||
"DeprecationWarning",
|
||||
"EOFError",
|
||||
"Ellipsis",
|
||||
"EnvironmentError",
|
||||
"Exception",
|
||||
"False",
|
||||
"FloatingPointError",
|
||||
"FutureWarning",
|
||||
"GeneratorExit",
|
||||
"IOError",
|
||||
"ImportError",
|
||||
"ImportWarning",
|
||||
"IndentationError",
|
||||
"IndexError",
|
||||
"KeyError",
|
||||
"KeyboardInterrupt",
|
||||
"LookupError",
|
||||
"MemoryError",
|
||||
"NameError",
|
||||
"None",
|
||||
"NotImplemented",
|
||||
"NotImplementedError",
|
||||
"OSError",
|
||||
"OverflowError",
|
||||
"PendingDeprecationWarning",
|
||||
"ReferenceError",
|
||||
"RuntimeError",
|
||||
"RuntimeWarning",
|
||||
"StopIteration",
|
||||
"SyntaxError",
|
||||
"SyntaxWarning",
|
||||
"SystemError",
|
||||
"SystemExit",
|
||||
"TabError",
|
||||
"True",
|
||||
"TypeError",
|
||||
"UnboundLocalError",
|
||||
"UnicodeDecodeError",
|
||||
"UnicodeEncodeError",
|
||||
"UnicodeError",
|
||||
"UnicodeTranslateError",
|
||||
"UnicodeWarning",
|
||||
"UserWarning",
|
||||
"ValueError",
|
||||
"Warning",
|
||||
"ZeroDivisionError",
|
||||
"__build_class__",
|
||||
"__debug__",
|
||||
"__import__",
|
||||
"abs",
|
||||
"all",
|
||||
"any",
|
||||
"ascii",
|
||||
"bin",
|
||||
"bool",
|
||||
"bytearray",
|
||||
"bytes",
|
||||
"callable",
|
||||
"chr",
|
||||
"classmethod",
|
||||
"compile",
|
||||
"complex",
|
||||
"delattr",
|
||||
"dict",
|
||||
"dir",
|
||||
"divmod",
|
||||
"enumerate",
|
||||
"eval",
|
||||
"exec",
|
||||
"filter",
|
||||
"float",
|
||||
"format",
|
||||
"frozenset",
|
||||
"getattr",
|
||||
"globals",
|
||||
"hasattr",
|
||||
"hash",
|
||||
"hex",
|
||||
"id",
|
||||
"input",
|
||||
"int",
|
||||
"isinstance",
|
||||
"issubclass",
|
||||
"iter",
|
||||
"len",
|
||||
"list",
|
||||
"locals",
|
||||
"map",
|
||||
"max",
|
||||
"memoryview",
|
||||
"min",
|
||||
"next",
|
||||
"object",
|
||||
"oct",
|
||||
"open",
|
||||
"ord",
|
||||
"pow",
|
||||
"print",
|
||||
"property",
|
||||
"range",
|
||||
"repr",
|
||||
"reversed",
|
||||
"round",
|
||||
"set",
|
||||
"setattr",
|
||||
"slice",
|
||||
"sorted",
|
||||
"staticmethod",
|
||||
"str",
|
||||
"sum",
|
||||
"super",
|
||||
"tuple",
|
||||
"type",
|
||||
"vars",
|
||||
"zip",
|
||||
}
|
|
@ -123,6 +123,13 @@ CACHE_DIR = f"{DATA_DIR}/cache"
|
|||
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
####################################
|
||||
# Functions DIR
|
||||
####################################
|
||||
|
||||
FUNCTIONS_DIR = f"{DATA_DIR}/functions"
|
||||
Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
####################################
|
||||
# Docs DIR
|
||||
####################################
|
||||
|
|
|
@ -15,6 +15,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
|
|||
from apps.ollama.main import app as ollama_app
|
||||
from apps.openai.main import app as openai_app
|
||||
from apps.audio.main import app as audio_app
|
||||
from apps.functions.main import app as functions_app
|
||||
from apps.images.main import app as images_app
|
||||
from apps.rag.main import app as rag_app
|
||||
|
||||
|
@ -61,10 +62,10 @@ app.mount("/api/v1", webui_app)
|
|||
|
||||
app.mount("/ollama/api", ollama_app)
|
||||
app.mount("/openai/api", openai_app)
|
||||
|
||||
app.mount("/images/api/v1", images_app)
|
||||
app.mount("/audio/api/v1", audio_app)
|
||||
app.mount("/rag/api/v1", rag_app)
|
||||
app.mount("/functions/api/v1", functions_app)
|
||||
|
||||
|
||||
@app.get("/api/config")
|
||||
|
|
Loading…
Reference in a new issue