Skip to content

Commit

Permalink
adds support for local dev runners from CLI (#521)
Browse files Browse the repository at this point in the history
  • Loading branch information
zeiler authored Feb 18, 2025
1 parent aff0215 commit 821a0bb
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 8 deletions.
12 changes: 12 additions & 0 deletions clarifai/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,18 @@ def run_locally(model_path, port, mode, keep_env, keep_image):
click.echo(f"Failed to starts model server locally: {e}", err=True)


@model.command()
@click.option(
    '--model_path',
    type=click.Path(exists=True),
    required=True,
    help='Path to the model directory.')
def local_dev(model_path):
  """Run the model as a local dev runner to help debug your model connected to the API. You must set several envvars such as CLARIFAI_PAT, CLARIFAI_RUNNER_ID, CLARIFAI_NODEPOOL_ID, CLARIFAI_COMPUTE_CLUSTER_ID. """
  # Import lazily so the CLI stays fast to load when this command isn't used.
  from clarifai.runners import server

  # Delegate to the shared server entrypoint with defaults (runner protocol, no gRPC).
  server.serve(model_path)


@model.command()
@click.option(
'--config',
Expand Down
29 changes: 21 additions & 8 deletions clarifai/runners/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,30 +68,43 @@ def main():

parsed_args = parser.parse_args()

builder = ModelBuilder(parsed_args.model_path, download_validation_only=True)
serve(parsed_args.model_path, parsed_args.port, parsed_args.pool_size,
parsed_args.max_queue_size, parsed_args.max_msg_length, parsed_args.enable_tls,
parsed_args.grpc)


def serve(model_path,
port=8000,
pool_size=32,
max_queue_size=10,
max_msg_length=1024 * 1024 * 1024,
enable_tls=False,
grpc=False):

builder = ModelBuilder(model_path, download_validation_only=True)

model = builder.create_model_instance()

# Setup the grpc server for local development.
if parsed_args.grpc:
if grpc:

# initialize the servicer with the runner so that it gets the predict(), generate(), stream() classes.
servicer = ModelServicer(model)

server = GRPCServer(
futures.ThreadPoolExecutor(
max_workers=parsed_args.pool_size,
max_workers=pool_size,
thread_name_prefix="ServeCalls",
),
parsed_args.max_msg_length,
parsed_args.max_queue_size,
max_msg_length,
max_queue_size,
)
server.add_port_to_server('[::]:%s' % parsed_args.port, parsed_args.enable_tls)
server.add_port_to_server('[::]:%s' % port, enable_tls)

service_pb2_grpc.add_V2Servicer_to_server(servicer, server)
server.start()
logger.info("Started server on port %s", parsed_args.port)
logger.info(f"Access the model at http://localhost:{parsed_args.port}")
logger.info("Started server on port %s", port)
logger.info(f"Access the model at http://localhost:{port}")
server.wait_for_termination()
else: # start the runner with the proper env variables and as a runner protocol.

Expand Down

0 comments on commit 821a0bb

Please sign in to comment.