Refactor pipeline downloads #332

Draft · wants to merge 6 commits into base: main · showing changes from all commits
config.py: 9 additions & 0 deletions

```diff
@@ -1,5 +1,6 @@
 import os
 
+
 ####################################
 # Load .env file
 ####################################
@@ -12,4 +13,12 @@
     print("dotenv not installed, skipping...")
 
 API_KEY = os.getenv("PIPELINES_API_KEY", "0p3n-w3bu!")
 
+LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
+
+PIPELINES_DIR = os.getenv("PIPELINES_DIR", "./pipelines")
+RESET_PIPELINES_DIR = os.getenv("RESET_PIPELINES_DIR", "false").lower() == "true"
+PIPELINES_REQUIREMENTS_PATH = os.getenv("PIPELINES_REQUIREMENTS_PATH")
+PIPELINES_URLS = os.getenv("PIPELINES_URLS")
+
+SUPPRESS_PIP_OUTPUT = os.getenv("SUPPRESS_PIP_OUTPUT", "true").lower() == "true"
```
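All of the new settings are plain environment variables. As an illustration only (the actual parsing of PIPELINES_URLS lives in utils/pipelines/downloads.py, which is not part of this view), a list-valued variable could be consumed like this, assuming a semicolon-separated format:

```python
# Hypothetical helper, not from the PR: split PIPELINES_URLS into URLs.
# The semicolon separator is an assumption; the diff does not show the format.
import os


def parse_pipeline_urls(raw: str | None) -> list[str]:
    """Return a clean list of pipeline URLs, dropping empty entries."""
    if not raw:
        return []
    return [url.strip() for url in raw.split(";") if url.strip()]


if __name__ == "__main__":
    print(parse_pipeline_urls(os.getenv("PIPELINES_URLS")))
```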
main.py: 57 additions & 99 deletions
```diff
@@ -2,37 +2,43 @@
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.concurrency import run_in_threadpool
 
-
-from starlette.responses import StreamingResponse, Response
-from pydantic import BaseModel, ConfigDict
-from typing import List, Union, Generator, Iterator
-
-
-from utils.pipelines.auth import bearer_security, get_current_user
-from utils.pipelines.main import get_last_user_message, stream_message_template
-from utils.pipelines.misc import convert_to_raw_url
+from pydantic import BaseModel
+from starlette.responses import StreamingResponse
+from typing import Generator, Iterator
 
 from contextlib import asynccontextmanager
 from concurrent.futures import ThreadPoolExecutor
-from urllib.parse import urlparse
+from schemas import FilterForm, OpenAIChatCompletionForm
 
 import shutil
-import aiohttp
 import os
 import importlib.util
-import logging
 import time
 import json
 import uuid
-import sys
-import subprocess
 
-from schemas import FilterForm, OpenAIChatCompletionForm
 
-from config import API_KEY, PIPELINES_DIR
+from utils.pipelines.logger import setup_logger
 
-if not os.path.exists(PIPELINES_DIR):
-    os.makedirs(PIPELINES_DIR)
+logger = setup_logger(__name__)
+
+from utils.pipelines.auth import bearer_security, get_current_user
+from utils.pipelines.downloads import (
+    download_file_to_folder,
+    download_pipelines,
+    install_requirements,
+    install_requirements_from_file,
+    reset_pipelines_dir,
+)
+from utils.pipelines.main import get_last_user_message, stream_message_template
+from utils.pipelines.misc import convert_to_raw_url
+
+from config import (
+    API_KEY,
+    PIPELINES_DIR,
+    PIPELINES_REQUIREMENTS_PATH,
+    PIPELINES_URLS,
+    RESET_PIPELINES_DIR,
+)
 
-
 PIPELINES = {}
```
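utils/pipelines/logger.py is not included in this PR view. A minimal sketch of what setup_logger might look like, assuming it honors the LOG_LEVEL setting that config.py now exposes:

```python
# Sketch only; the PR's real setup_logger lives in utils/pipelines/logger.py,
# which this diff does not show.
import logging
import os


def setup_logger(name: str) -> logging.Logger:
    logger = logging.getLogger(name)
    if not logger.handlers:  # avoid stacking handlers on repeated calls
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("%(asctime)s | %(levelname)s | %(name)s | %(message)s")
        )
        logger.addHandler(handler)
    logger.setLevel(os.getenv("LOG_LEVEL", "INFO"))
    return logger
```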
```diff
@@ -106,62 +112,28 @@ def get_all_pipelines():
 
     return pipelines
 
-def parse_frontmatter(content):
-    frontmatter = {}
-    for line in content.split('\n'):
-        if ':' in line:
-            key, value = line.split(':', 1)
-            frontmatter[key.strip().lower()] = value.strip()
-    return frontmatter
-
-def install_frontmatter_requirements(requirements):
-    if requirements:
-        req_list = [req.strip() for req in requirements.split(',')]
-        for req in req_list:
-            print(f"Installing requirement: {req}")
-            subprocess.check_call([sys.executable, "-m", "pip", "install", req])
-    else:
-        print("No requirements found in frontmatter.")
-
 async def load_module_from_path(module_name, module_path):
-
     try:
-        # Read the module content
-        with open(module_path, 'r') as file:
-            content = file.read()
-
-        # Parse frontmatter
-        frontmatter = {}
-        if content.startswith('"""'):
-            end = content.find('"""', 3)
-            if end != -1:
-                frontmatter_content = content[3:end]
-                frontmatter = parse_frontmatter(frontmatter_content)
-
-        # Install requirements if specified
-        if 'requirements' in frontmatter:
-            install_frontmatter_requirements(frontmatter['requirements'])
-
-        # Load the module
+        await install_requirements_from_file(module_path)
         spec = importlib.util.spec_from_file_location(module_name, module_path)
         module = importlib.util.module_from_spec(spec)
         spec.loader.exec_module(module)
-        print(f"Loaded module: {module.__name__}")
+        logger.info(f"Loaded module: {module.__name__}")
         if hasattr(module, "Pipeline"):
            return module.Pipeline()
         else:
            raise Exception("No Pipeline class found")
     except Exception as e:
-        print(f"Error loading module: {module_name}")
+        logger.error(f"Error loading module: {module_name}")
 
         # Move the file to the error folder
         failed_pipelines_folder = os.path.join(PIPELINES_DIR, "failed")
         if not os.path.exists(failed_pipelines_folder):
             os.makedirs(failed_pipelines_folder)
 
         failed_file_path = os.path.join(failed_pipelines_folder, f"{module_name}.py")
         os.rename(module_path, failed_file_path)
-        print(e)
+        logger.error(str(e))
         return None
```
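The frontmatter parsing and per-requirement pip install deleted above now sit behind install_requirements_from_file in utils/pipelines/downloads.py, which this view does not show. A sketch of an equivalent, assuming the docstring convention (a `requirements: pkg-a, pkg-b` line in the module's leading triple-quoted string) is unchanged and that pip is pushed off the event loop:

```python
# Assumed reimplementation; the PR's actual helper is not visible here.
import subprocess
import sys

from fastapi.concurrency import run_in_threadpool


def _frontmatter_requirements(module_path: str) -> list[str]:
    """Pull the comma-separated requirements line out of a module docstring."""
    with open(module_path, "r") as f:
        content = f.read()
    if not content.startswith('"""'):
        return []
    end = content.find('"""', 3)
    if end == -1:
        return []
    for line in content[3:end].splitlines():
        if ":" in line:
            key, value = line.split(":", 1)
            if key.strip().lower() == "requirements":
                return [req.strip() for req in value.split(",") if req.strip()]
    return []


async def install_requirements_from_file(module_path: str) -> None:
    for req in _frontmatter_requirements(module_path):
        await run_in_threadpool(
            subprocess.check_call, [sys.executable, "-m", "pip", "install", req]
        )
```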


```diff
@@ -172,20 +144,25 @@ async def load_modules_from_directory(directory):
     for filename in os.listdir(directory):
         if filename.endswith(".py"):
             module_name = filename[:-3]  # Remove the .py extension
+
+            # Skip __init__.py files for pipeline loading
+            if module_name == "__init__":
+                continue
+
             module_path = os.path.join(directory, filename)
 
             # Create subfolder matching the filename without the .py extension
             subfolder_path = os.path.join(directory, module_name)
             if not os.path.exists(subfolder_path):
                 os.makedirs(subfolder_path)
-                logging.info(f"Created subfolder: {subfolder_path}")
+                logger.debug(f"Created subfolder: {subfolder_path}")
 
             # Create a valves.json file if it doesn't exist
             valves_json_path = os.path.join(subfolder_path, "valves.json")
             if not os.path.exists(valves_json_path):
                 with open(valves_json_path, "w") as f:
                     json.dump({}, f)
-                logging.info(f"Created valves.json in: {subfolder_path}")
+                logger.debug(f"Created valves.json in: {subfolder_path}")
 
             pipeline = await load_module_from_path(module_name, module_path)
             if pipeline:
@@ -203,20 +180,26 @@ async def load_modules_from_directory(directory):
                             valves = ValvesModel(**combined_valves)
                             pipeline.valves = valves
 
-                            logging.info(f"Updated valves for module: {module_name}")
+                            logger.debug(f"Updated valves for module: {module_name}")
 
                 pipeline_id = pipeline.id if hasattr(pipeline, "id") else module_name
                 PIPELINE_MODULES[pipeline_id] = pipeline
                 PIPELINE_NAMES[pipeline_id] = module_name
-                logging.info(f"Loaded module: {module_name}")
+                logger.info(f"Loaded module: {module_name}")
             else:
-                logging.warning(f"No Pipeline class found in {module_name}")
+                logger.warning(f"No Pipeline class found in {module_name}")
 
     global PIPELINES
     PIPELINES = get_all_pipelines()
 
 
 async def on_startup():
+    if not os.path.exists(PIPELINES_DIR):
+        os.makedirs(PIPELINES_DIR)
+
+    await reset_pipelines_dir(PIPELINES_DIR, RESET_PIPELINES_DIR)
+    await install_requirements(PIPELINES_REQUIREMENTS_PATH)
+    await download_pipelines(PIPELINES_URLS, PIPELINES_DIR)
     await load_modules_from_directory(PIPELINES_DIR)
 
     for module in PIPELINE_MODULES.values():
```
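Startup is now an explicit sequence: ensure the directory exists, optionally reset it, install a shared requirements file, download the configured pipelines, then load modules. The helper signatures follow from these call sites; the bodies below are sketches, not the PR's implementation (download_pipelines presumably splits PIPELINES_URLS and delegates to the single-file downloader sketched after the add_pipeline hunk below):

```python
# Sketches inferred from the on_startup() call sites; the real code lives in
# utils/pipelines/downloads.py, which this PR view does not include.
import os
import shutil
import subprocess
import sys

from fastapi.concurrency import run_in_threadpool


async def reset_pipelines_dir(pipelines_dir: str, reset: bool) -> None:
    """Optionally wipe the pipelines dir so each start is reproducible."""
    if reset and os.path.exists(pipelines_dir):
        shutil.rmtree(pipelines_dir)
    os.makedirs(pipelines_dir, exist_ok=True)


async def install_requirements(requirements_path: str | None) -> None:
    """Install a shared requirements file before any pipeline is loaded."""
    if not requirements_path or not os.path.exists(requirements_path):
        return
    await run_in_threadpool(
        subprocess.check_call,
        [sys.executable, "-m", "pip", "install", "-r", requirements_path],
    )
```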
```diff
@@ -277,7 +260,7 @@ async def check_url(request: Request, call_next):
 
 @app.get("/v1/models")
 @app.get("/models")
-async def get_models():
+async def get_models(user: str = Depends(get_current_user)):
     """
     Returns the available pipelines
     """
```
```diff
@@ -354,29 +337,6 @@ class AddPipelineForm(BaseModel):
     url: str
 
 
-async def download_file(url: str, dest_folder: str):
-    filename = os.path.basename(urlparse(url).path)
-    if not filename.endswith(".py"):
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail="URL must point to a Python file",
-        )
-
-    file_path = os.path.join(dest_folder, filename)
-
-    async with aiohttp.ClientSession() as session:
-        async with session.get(url) as response:
-            if response.status != 200:
-                raise HTTPException(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    detail="Failed to download file",
-                )
-            with open(file_path, "wb") as f:
-                f.write(await response.read())
-
-    return file_path
-
-
 @app.post("/v1/pipelines/add")
 @app.post("/pipelines/add")
 async def add_pipeline(
```
```diff
@@ -391,8 +351,8 @@ async def add_pipeline(
     try:
         url = convert_to_raw_url(form_data.url)
 
-        print(url)
-        file_path = await download_file(url, dest_folder=PIPELINES_DIR)
+        logger.debug(f"Downloading pipeline from {url}")
+        file_path = await download_file_to_folder(url, dest_folder=PIPELINES_DIR)
         await reload()
         return {
             "status": True,
```
```diff
@@ -576,7 +536,7 @@ async def update_valves(pipeline_id: str, form_data: dict):
         if hasattr(pipeline, "on_valves_updated"):
             await pipeline.on_valves_updated()
     except Exception as e:
-        print(e)
+        logger.error(e)
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail=f"{str(e)}",
@@ -610,7 +570,7 @@ async def filter_inlet(pipeline_id: str, form_data: FilterForm):
         else:
             return form_data.body
     except Exception as e:
-        print(e)
+        logger.error(e)
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail=f"{str(e)}",
@@ -642,7 +602,7 @@ async def filter_outlet(pipeline_id: str, form_data: FilterForm):
         else:
             return form_data.body
     except Exception as e:
-        print(e)
+        logger.error(e)
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail=f"{str(e)}",
@@ -665,13 +625,11 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
         )
 
     def job():
-        print(form_data.model)
+        logger.error(form_data.model)
 
         pipeline = app.state.PIPELINES[form_data.model]
         pipeline_id = form_data.model
 
-        print(pipeline_id)
-
         if pipeline["type"] == "manifold":
             manifold_id, pipeline_id = pipeline_id.split(".", 1)
             pipe = PIPELINE_MODULES[manifold_id].pipe
@@ -688,11 +646,11 @@ def stream_content():
                     body=form_data.model_dump(),
                 )
 
-                logging.info(f"stream:true:{res}")
+                logger.info(f"stream:true:{res}")
 
                 if isinstance(res, str):
                     message = stream_message_template(form_data.model, res)
-                    logging.info(f"stream_content:str:{message}")
+                    logger.info(f"stream_content:str:{message}")
                     yield f"data: {json.dumps(message)}\n\n"
 
                 if isinstance(res, Iterator):
@@ -706,7 +664,7 @@ def stream_content():
                         except:
                             pass
 
-                        logging.info(f"stream_content:Generator:{line}")
+                        logger.info(f"stream_content:Generator:{line}")
 
                         if line.startswith("data:"):
                             yield f"{line}\n\n"
@@ -741,7 +699,7 @@ def stream_content():
                 messages=messages,
                 body=form_data.model_dump(),
             )
-            logging.info(f"stream:false:{res}")
+            logger.info(f"stream:false:{res}")
 
             if isinstance(res, dict):
                 return res
@@ -758,7 +716,7 @@ def stream_content():
             for stream in res:
                 message = f"{message}{stream}"
 
-            logging.info(f"stream:false:{message}")
+            logger.info(f"stream:false:{message}")
             return {
                 "id": f"{form_data.model}-{str(uuid.uuid4())}",
                 "object": "chat.completion",
```
requirements-minimum.txt: 2 additions & 2 deletions

```diff
@@ -11,5 +11,5 @@ PyJWT[crypto]
 
 requests==2.32.2
 aiohttp==3.9.5
-httpx
-
+gitpython==3.1.43
+httpx
```
requirements.txt: 1 addition & 0 deletions

```diff
@@ -11,6 +11,7 @@ PyJWT[crypto]
 
 requests==2.32.2
 aiohttp==3.9.5
+gitpython==3.1.43
 httpx
 
 # AI libraries
```
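gitpython==3.1.43 is added to both requirements files, which suggests pipeline sources may now include git repositories rather than only raw .py URLs. A hypothetical clone call (the repository URL and destination are placeholders; the PR's actual git handling is not visible in this view):

```python
# Hypothetical GitPython usage; not code from this PR.
from git import Repo

repo = Repo.clone_from(
    "https://github.com/open-webui/pipelines.git",  # placeholder URL
    "./pipelines/example-repo",
    depth=1,  # extra kwargs are forwarded to `git clone`
)
print(repo.head.commit.hexsha)
```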