diff --git a/backend/apps/functions/main.py b/backend/apps/functions/main.py new file mode 100644 index 00000000..1bdf6ca3 --- /dev/null +++ b/backend/apps/functions/main.py @@ -0,0 +1,172 @@ +from pathlib import Path +import ast +import builtins + + +from fastapi import ( + FastAPI, + Request, + Depends, + HTTPException, + status, + UploadFile, + File, + Form, +) + +from fastapi.middleware.cors import CORSMiddleware +from apps.functions.security import ALLOWED_MODULES, ALLOWED_BUILTINS, custom_import +from utils.utils import get_current_user, get_admin_user + + +from config import FUNCTIONS_DIR +from constants import ERROR_MESSAGES + + +from pydantic import BaseModel + +from typing import Optional + +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +@app.get("/") +async def get_status(): + return {"status": True} + + +class FunctionForm(BaseModel): + name: str + content: str + + +@app.post("/add") +def add_function( + form_data: FunctionForm, + user=Depends(get_admin_user), +): + try: + filename = f"{FUNCTIONS_DIR}/{form_data.name}.py" + if not Path(filename).exists(): + with open(filename, "w") as file: + file.write(form_data.content) + return f"{form_data.name}.py" in list( + map(lambda x: x.name, Path(FUNCTIONS_DIR).rglob("./*")) + ) + else: + raise Exception("Function already exists") + except Exception as e: + print(e) + return False + + +@app.post("/update") +def update_function( + form_data: FunctionForm, + user=Depends(get_admin_user), +): + try: + filename = f"{FUNCTIONS_DIR}/{form_data.name}.py" + if Path(filename).exists(): + with open(filename, "w") as file: + file.write(form_data.content) + return f"{form_data.name}.py" in list( + map(lambda x: x.name, Path(FUNCTIONS_DIR).rglob("./*")) + ) + else: + raise Exception("Function does not exist") + except Exception as e: + print(e) + return False + + +@app.get("/check/{function}") +def check_function( + function: str, + user=Depends(get_admin_user), +): + filename = f"{FUNCTIONS_DIR}/{function}.py" + + # Check if the function file exists + if not Path(filename).is_file(): + raise HTTPException(status_code=404, detail="Function not found") + + # Read the code from the file + with open(filename, "r") as file: + code = file.read() + + return {"name": function, "content": code} + + +@app.get("/list") +def list_functions( + user=Depends(get_admin_user), +): + files = list(map(lambda x: x.name, Path(FUNCTIONS_DIR).rglob("./*"))) + return files + + +def validate_imports(code): + try: + tree = ast.parse(code) + except SyntaxError as e: + raise HTTPException(status_code=400, detail=f"Syntax error in function: {e}") + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + module_names = [alias.name for alias in node.names] + elif isinstance(node, ast.ImportFrom): + module_names = [node.module] + else: + continue + + for name in module_names: + if name not in ALLOWED_MODULES: + raise HTTPException( + status_code=400, detail=f"Import of module {name} is not allowed" + ) + + +@app.post("/exec/{function}") +def exec_function( + function: str, + kwargs: Optional[dict] = None, + user=Depends(get_current_user), +): + filename = f"{FUNCTIONS_DIR}/{function}.py" + + # Check if the function file exists + if not Path(filename).is_file(): + raise HTTPException(status_code=404, detail="Function not found") + + # Read the code from the file + with open(filename, "r") as file: + code = file.read() + + validate_imports(code) + + try: + # Execute the code within a restricted namespace + namespace = {name: getattr(builtins, name) for name in ALLOWED_BUILTINS} + namespace["__import__"] = custom_import + exec(code, namespace) + except Exception as e: + raise HTTPException(status_code=400, detail=f"Function: {e}") + + # Check if the function exists in the namespace + if "main" not in namespace or not callable(namespace["main"]): + raise HTTPException(status_code=400, detail="Invalid function") + + try: + # Execute the function with provided kwargs + result = namespace["main"](kwargs) if kwargs else namespace["main"]() + return result + except Exception as e: + raise HTTPException(status_code=400, detail=f"Function: {e}") diff --git a/backend/apps/functions/security.py b/backend/apps/functions/security.py new file mode 100644 index 00000000..4783923f --- /dev/null +++ b/backend/apps/functions/security.py @@ -0,0 +1,140 @@ +ALLOWED_MODULES = { + "pydantic", + "math", + "json", + "time", + "datetime", + "requests", +} # Add allowed modules here + + +def custom_import(name, globals=None, locals=None, fromlist=(), level=0): + if name in ALLOWED_MODULES: + return __import__(name, globals, locals, fromlist, level) + raise ImportError(f"Import of module {name} is not allowed") + + +# Define a restricted set of builtins +ALLOWED_BUILTINS = { + "ArithmeticError", + "AssertionError", + "AttributeError", + "BaseException", + "BufferError", + "BytesWarning", + "DeprecationWarning", + "EOFError", + "Ellipsis", + "EnvironmentError", + "Exception", + "False", + "FloatingPointError", + "FutureWarning", + "GeneratorExit", + "IOError", + "ImportError", + "ImportWarning", + "IndentationError", + "IndexError", + "KeyError", + "KeyboardInterrupt", + "LookupError", + "MemoryError", + "NameError", + "None", + "NotImplemented", + "NotImplementedError", + "OSError", + "OverflowError", + "PendingDeprecationWarning", + "ReferenceError", + "RuntimeError", + "RuntimeWarning", + "StopIteration", + "SyntaxError", + "SyntaxWarning", + "SystemError", + "SystemExit", + "TabError", + "True", + "TypeError", + "UnboundLocalError", + "UnicodeDecodeError", + "UnicodeEncodeError", + "UnicodeError", + "UnicodeTranslateError", + "UnicodeWarning", + "UserWarning", + "ValueError", + "Warning", + "ZeroDivisionError", + "__build_class__", + "__debug__", + "__import__", + "abs", + "all", + "any", + "ascii", + "bin", + "bool", + "bytearray", + "bytes", + "callable", + "chr", + "classmethod", + "compile", + "complex", + "delattr", + "dict", + "dir", + "divmod", + "enumerate", + "eval", + "exec", + "filter", + "float", + "format", + "frozenset", + "getattr", + "globals", + "hasattr", + "hash", + "hex", + "id", + "input", + "int", + "isinstance", + "issubclass", + "iter", + "len", + "list", + "locals", + "map", + "max", + "memoryview", + "min", + "next", + "object", + "oct", + "open", + "ord", + "pow", + "print", + "property", + "range", + "repr", + "reversed", + "round", + "set", + "setattr", + "slice", + "sorted", + "staticmethod", + "str", + "sum", + "super", + "tuple", + "type", + "vars", + "zip", +} diff --git a/backend/config.py b/backend/config.py index 440256c4..ed38eb81 100644 --- a/backend/config.py +++ b/backend/config.py @@ -44,6 +44,13 @@ CACHE_DIR = f"{DATA_DIR}/cache" Path(CACHE_DIR).mkdir(parents=True, exist_ok=True) +#################################### +# Functions DIR +#################################### + +FUNCTIONS_DIR = f"{DATA_DIR}/functions" +Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True) + #################################### # Docs DIR #################################### diff --git a/backend/main.py b/backend/main.py index 3a28670e..9dd8c8ae 100644 --- a/backend/main.py +++ b/backend/main.py @@ -11,7 +11,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException from apps.ollama.main import app as ollama_app from apps.openai.main import app as openai_app from apps.audio.main import app as audio_app - +from apps.functions.main import app as functions_app from apps.web.main import app as webui_app from apps.rag.main import app as rag_app @@ -58,8 +58,9 @@ app.mount("/api/v1", webui_app) app.mount("/ollama/api", ollama_app) app.mount("/openai/api", openai_app) -app.mount("/audio/api/v1", audio_app) app.mount("/rag/api/v1", rag_app) +app.mount("/audio/api/v1", audio_app) +app.mount("/functions/api/v1", functions_app) app.mount(