Skip to content

Commit

Permalink
add functionality in api
Browse files Browse the repository at this point in the history
  • Loading branch information
ronylpatil committed Mar 18, 2024
1 parent 86e181f commit 0771299
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions prod/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,11 @@ class WineqIp(BaseModel) :
params = yaml.safe_load(open(f'{home_dir.as_posix()}/params.yaml'))
mlflow.set_tracking_uri(params['mlflow_config']['mlflow_tracking_uri'])
client = MlflowClient()
# fetch model by model version
model_details = client.get_model_version_by_alias(name = params['mlflow_config']['reg_model_name'], alias = params['mlflow_config']['stage'])
model = load_model(f'models:/{model_details.name}/{model_details.version}')
# model = load_model(f'models:/{model_details.name}/{model_details.version}')
# fetch model by model alias
model = load_model(f"models:/{model_details.name}@{params['mlflow_config']['stage']}")

def feat_gen(user_input: dict) -> dict :
user_input['total_acidity'] = user_input['fixed_acidity'] + user_input['volatile_acidity'] + user_input['citric_acid']
Expand All @@ -54,9 +57,14 @@ def feat_gen(user_input: dict) -> dict :

return user_input


@app.get('/')
def root() :
return {'api status': 'up & running'}

# Define API route to make predictions
@app.post('/predict')
async def predict_wineq(input_data: WineqIp) :
def predict_wineq(input_data: WineqIp) :
processed_data = feat_gen(input_data.model_dump())

data = [[processed_data['fixed_acidity'], processed_data['volatile_acidity'], processed_data['citric_acid'], processed_data['residual_sugar'],
Expand All @@ -70,15 +78,20 @@ async def predict_wineq(input_data: WineqIp) :
# and numpy array is not compatible with any web framework
# I stuck here for 3 hrs, this error was damn...
predictions = model.predict(data).tolist()
pred_probability = model.predict_proba(data).tolist()
pred_probability = [round(i, 5) for i in model.predict_proba(data).tolist()[0]]
return {'predictions': predictions, 'probability': pred_probability}

@app.get('/details')
def get_details() :
return {'model name': model_details.name, 'model version': model_details.version,
'model aliases': model_details.aliases, 'model run_id': model_details.run_id,
'model description': model_details.description, 'model tags': model_details.tags}

if __name__ == '__main__' :
import uvicorn
uvicorn.run('prod.api:app', host = '127.0.0.1', port = 8000, log_level = 'debug',
proxy_headers = True, reload = True)


# app_name: api (file name)
# port: 8000 (default)
# cmd: uvicorn prod.api:app --reload

0 comments on commit 0771299

Please sign in to comment.