From 07712995eb0c682aea313184219e03045295a79c Mon Sep 17 00:00:00 2001 From: ronilpatil Date: Mon, 18 Mar 2024 11:25:17 +0530 Subject: [PATCH] add functionality in api --- prod/api.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/prod/api.py b/prod/api.py index ae3a0a0..025f57f 100644 --- a/prod/api.py +++ b/prod/api.py @@ -30,8 +30,11 @@ class WineqIp(BaseModel) : params = yaml.safe_load(open(f'{home_dir.as_posix()}/params.yaml')) mlflow.set_tracking_uri(params['mlflow_config']['mlflow_tracking_uri']) client = MlflowClient() +# fetch model by model version model_details = client.get_model_version_by_alias(name = params['mlflow_config']['reg_model_name'], alias = params['mlflow_config']['stage']) -model = load_model(f'models:/{model_details.name}/{model_details.version}') +# model = load_model(f'models:/{model_details.name}/{model_details.version}') +# fetch model by model alias +model = load_model(f"models:/{model_details.name}@{params['mlflow_config']['stage']}") def feat_gen(user_input: dict) -> dict : user_input['total_acidity'] = user_input['fixed_acidity'] + user_input['volatile_acidity'] + user_input['citric_acid'] @@ -54,9 +57,14 @@ def feat_gen(user_input: dict) -> dict : return user_input + +@app.get('/') +def root() : + return {'api status': 'up & running'} + # Define API route to make predictions @app.post('/predict') -async def predict_wineq(input_data: WineqIp) : +def predict_wineq(input_data: WineqIp) : processed_data = feat_gen(input_data.model_dump()) data = [[processed_data['fixed_acidity'], processed_data['volatile_acidity'], processed_data['citric_acid'], processed_data['residual_sugar'], @@ -70,15 +78,20 @@ async def predict_wineq(input_data: WineqIp) : # and numpy array is not compatible with any web framework # I stuck here for 3 hrs, this error was damn... predictions = model.predict(data).tolist() - pred_probability = model.predict_proba(data).tolist() + pred_probability = [round(i, 5) for i in model.predict_proba(data).tolist()[0]] return {'predictions': predictions, 'probability': pred_probability} +@app.get('/details') +def get_details() : + return {'model name': model_details.name, 'model version': model_details.version, + 'model aliases': model_details.aliases, 'model run_id': model_details.run_id, + 'model description': model_details.description, 'model tags': model_details.tags} + if __name__ == '__main__' : import uvicorn uvicorn.run('prod.api:app', host = '127.0.0.1', port = 8000, log_level = 'debug', proxy_headers = True, reload = True) - # app_name: api (file name) # port: 8000 (default) # cmd: uvicorn prod.api:app --reload