-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
80 lines (66 loc) · 2.5 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# ===========================
# Module: FastAPI setup
# Author: Kenneth Leung
# Last Modified: 6 Dec 2021
# ===========================
# Command to execute script: uvicorn main:app --host=0.0.0.0 --port=8000
import pandas as pd
import io
import h2o
from fastapi import FastAPI, File
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse, HTMLResponse
import mlflow
import mlflow.h2o
from mlflow.tracking import MlflowClient
from mlflow.entities import ViewType
from utils.data_processing import match_col_types, separate_id_col
# Create FastAPI instance
app = FastAPI()

# Initiate H2O instance and MLflow client
h2o.init()
client = MlflowClient()

# Load best model (based on logloss) amongst all runs in all experiments.
# `MlflowClient.list_experiments` was removed in MLflow 2.0 in favour of
# `search_experiments`; use the newer API when present so startup works on
# both old and new MLflow versions.
_list_experiments = getattr(client, "search_experiments", None) or client.list_experiments
all_exps = [exp.experiment_id for exp in _list_experiments()]
runs = mlflow.search_runs(experiment_ids=all_exps, run_view_type=ViewType.ALL)

# Fail fast with an explicit message instead of an opaque KeyError/ValueError
# from `idxmin` when there is nothing to choose from.
if (
    runs.empty
    or 'metrics.log_loss' not in runs.columns
    or runs['metrics.log_loss'].isna().all()
):
    raise RuntimeError(
        'No MLflow run with a logged "log_loss" metric was found; '
        'cannot select a model to serve.'
    )

# Locate the lowest-logloss run once and read both identifiers from it
best_run = runs.loc[runs['metrics.log_loss'].idxmin()]
run_id, exp_id = best_run['run_id'], best_run['experiment_id']
print(f'Loading best model: Run {run_id} of Experiment {exp_id}')
# NOTE(review): path assumes the local ./mlruns file store relative to the
# working directory; a `runs:/{run_id}/model` URI would follow the tracking URI.
best_model = mlflow.h2o.load_model(f"mlruns/{exp_id}/{run_id}/artifacts/model/")
# Create POST endpoint with path '/predict'
@app.post("/predict")
async def predict(file: bytes = File(...)):
    """Score an uploaded CSV file with the best MLflow/H2O model.

    Parameters
    ----------
    file : bytes
        Raw contents of the uploaded file; parsed with ``pandas.read_csv``.

    Returns
    -------
    JSONResponse
        When ``separate_id_col`` reports an ID column (non-None name), a
        mapping of ID value -> prediction; otherwise a list of predictions
        in row order. A 400 response with a ``detail`` message is returned
        when the upload cannot be parsed as CSV.
    """
    print('[+] Initiate Prediction')
    # NOTE(review): the pandas/H2O calls below are blocking and run directly on
    # the event loop because this is ``async def``; a plain ``def`` would let
    # FastAPI offload them to its threadpool — confirm the H2O client tolerates
    # concurrent use before switching.
    try:
        test_df = pd.read_csv(io.BytesIO(file))
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        # Unparseable client input is a client error, not a server crash (500)
        return JSONResponse(
            status_code=400,
            content={'detail': f'Could not parse uploaded file as CSV: {err}'},
        )
    test_h2o = h2o.H2OFrame(test_df)

    # Separate ID column (if any)
    id_name, X_id, X_h2o = separate_id_col(test_h2o)

    # Match test set col types with train set
    X_h2o = match_col_types(X_h2o)

    # Generate predictions with best model (output is H2O frame)
    preds = best_model.predict(X_h2o)
    # Convert once; both response shapes are built from this plain list
    preds_list = preds.as_data_frame()['predict'].tolist()

    # Apply processing if dataset has ID column
    if id_name is not None:
        id_list = X_id.as_data_frame()[id_name].tolist()
        # NOTE(review): duplicate IDs collapse to the last prediction in the
        # dict — confirm IDs are expected to be unique.
        preds_final = dict(zip(id_list, preds_list))
    else:
        preds_final = preds_list

    json_compatible_item_data = jsonable_encoder(preds_final)
    return JSONResponse(content=json_compatible_item_data)
@app.get("/")
async def main():
    """Serve a minimal static HTML landing page at the site root.

    Returns
    -------
    HTMLResponse
        A welcome banner. Predictions are requested by POSTing a CSV file
        (form field ``file``) to ``/predict``.
    """
    # The <h2> element is now properly closed (it was left open before).
    content = """
<body>
<h2> Welcome to the End to End AutoML Project</h2>
</body>
"""
    return HTMLResponse(content=content)