# Source code for sparsify.app

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import atexit
import logging
import os
from typing import Any, Union

from flasgger import Swagger
from flask import Flask
from flask_cors import CORS

from sparsezoo.utils import clean_path, create_dirs
from sparsify.blueprints import (
    errors_blueprint,
    jobs_blueprint,
    model_repo_blueprint,
    projects_benchmark_blueprint,
    projects_blueprint,
    projects_data_blueprint,
    projects_model_blueprint,
    projects_optim_blueprint,
    projects_profiles_blueprint,
    system_blueprint,
    ui_blueprint,
)
from sparsify.log import set_logging_level
from sparsify.models import database_setup
from sparsify.workers import JobWorkerManager


# Names exported by ``from sparsify.app import *``.
__all__ = ["run", "main"]

# Logger for this module, named after the module's import path.
_LOGGER = logging.getLogger(__name__)


def _validate_working_dir(working_dir: str) -> str:
    """
    Resolve the sparsify working directory and make sure it exists.

    Resolution order: the given ``working_dir`` if truthy, else the
    ``NM_SERVER_WORKING_DIR`` environment variable if set and non-empty,
    else ``os.path.join("~", "sparsify")``. The chosen path is normalized
    with ``clean_path`` and then created with ``create_dirs``.

    :param working_dir: requested working directory; may be None or empty
    :return: the cleaned path of the working directory
    :raises RuntimeError: if the directory could not be created; the
        underlying error is chained as the exception's ``__cause__``
    """
    if not working_dir:
        working_dir = os.getenv("NM_SERVER_WORKING_DIR", "")

    if not working_dir:
        working_dir = os.path.join("~", "sparsify")

    working_dir = clean_path(working_dir)

    try:
        create_dirs(working_dir)
    except Exception as err:
        # Chain the original error so its traceback survives for debugging.
        raise RuntimeError(
            "Error while trying to create sparsify working_dir at {}: {}".format(
                working_dir, err
            )
        ) from err

    return working_dir


def _setup_logging(logging_level: str):
    """
    Set the sparsify logging level from a level name.

    :param logging_level: name of a level constant on the ``logging`` module
        (e.g. "INFO"); matched case-insensitively, so "info" also works.
        An int level is accepted and used as-is.
    """
    if isinstance(logging_level, int) and not isinstance(logging_level, bool):
        level = logging_level
    else:
        # Upper-case so "info" resolves like "INFO". Attributes of ``logging``
        # that are not int level constants (e.g. "basicConfig") fall through
        # to the validity check below and are rejected.
        level = getattr(logging, str(logging_level).upper(), None)

    if not isinstance(level, int) or isinstance(level, bool):
        # Don't forward an unresolved value to set_logging_level; log the
        # problem and leave the current level in effect.
        _LOGGER.error(
            "error setting logging level to {}: unknown logging level".format(
                logging_level
            )
        )
        return

    set_logging_level(level)


def _blueprints_setup(app: Flask):
    """
    Register each sparsify blueprint on ``app``, one at a time, in a fixed order.

    :param app: the Flask application to attach the blueprints to
    """
    ordered_blueprints = (
        errors_blueprint,
        jobs_blueprint,
        model_repo_blueprint,
        projects_blueprint,
        projects_benchmark_blueprint,
        projects_data_blueprint,
        projects_model_blueprint,
        projects_optim_blueprint,
        projects_profiles_blueprint,
        system_blueprint,
        ui_blueprint,
    )
    for blueprint in ordered_blueprints:
        app.register_blueprint(blueprint)


def _api_docs_setup(app: Flask):
    """
    Attach flasgger's Swagger API documentation to the Flask app.

    :param app: the Flask application to document
    """
    Swagger(app=app)


def _worker_setup():
    """
    Create a JobWorkerManager and start it.

    Before starting, the manager's ``shutdown`` is registered with ``atexit``
    so it runs when the interpreter exits.
    """
    job_manager = JobWorkerManager()
    atexit.register(lambda: job_manager.shutdown())
    job_manager.start()


def run(
    working_dir: str,
    host: str,
    port: int,
    debug: bool,
    logging_level: str,
    ui_path: Union[str, None],
):
    """
    Configure the sparsify Flask app and serve it; blocks inside ``app.run``.

    :param working_dir: directory to store state in; if falsy, falls back to
        the NM_SERVER_WORKING_DIR environment variable, then to
        ``os.path.join("~", "sparsify")``
    :param host: host address to bind the server to
    :param port: port to bind the server to
    :param debug: passed through to Flask's ``app.run`` as ``debug``
    :param logging_level: name of the logging level to report at, e.g. "INFO"
    :param ui_path: directory to render the UI from; None uses the ``ui``
        directory next to this file
    """
    working_dir = _validate_working_dir(working_dir)
    _setup_logging(logging_level)

    if ui_path is None:
        ui_path = os.path.join(os.path.dirname(clean_path(__file__)), "ui")

    # The real static routes come from the blueprints; the default static URL
    # path is moved to "/unused" so it does not clash with them.
    app = Flask("sparsify", static_url_path="/unused")
    app.config["MAX_CONTENT_LENGTH"] = 2 * 1024 * 1024 * 1024  # 2 GiB upload limit
    app.config["UI_PATH"] = ui_path
    CORS(app)

    database_setup(working_dir, app)
    _blueprints_setup(app)
    _api_docs_setup(app)
    # NOTE(review): with debug=True Flask enables its reloader by default,
    # which re-runs this module in a child process and could start a second
    # set of job workers; confirm whether use_reloader should be disabled.
    _worker_setup()

    app.run(host=host, port=port, debug=debug, threaded=True)
def parse_args(args: Union[list, None] = None) -> Any:
    """
    Parse the sparsify server command-line arguments.

    :param args: optional list of argument strings to parse; when None
        (the default), argparse reads ``sys.argv[1:]``
    :return: an ``argparse.Namespace`` with ``working_dir``, ``host``,
        ``port``, ``debug``, ``logging_level`` and ``ui_path``
    """
    parser = argparse.ArgumentParser(description="sparsify")
    parser.add_argument(
        "--working-dir",
        default=None,
        type=str,
        help="The path to the working directory to store state in, "
        "defaults to ~/sparsify",
    )
    parser.add_argument(
        "--host",
        default="0.0.0.0",
        type=str,
        help="The host path to launch the server on",
    )
    parser.add_argument(
        "--port", default=5543, type=int, help="The local port to launch the server on"
    )
    parser.add_argument(
        "--debug",
        default=False,
        action="store_true",
        help="Set to run in debug mode",
    )
    parser.add_argument(
        "--logging-level",
        default="INFO",
        type=str,
        help="The logging level to report at",
    )
    parser.add_argument(
        "--ui-path",
        default=None,
        type=str,
        help="The directory to render the UI from, generally should not be set. "
        "By default, will load from the UI packaged with sparsify "
        "under sparsify/ui",
    )

    return parser.parse_args(args)
def main():
    """
    Command-line entry point: parse the CLI arguments and pass them to ``run``.
    """
    args = parse_args()
    run(
        args.working_dir,
        args.host,
        args.port,
        args.debug,
        args.logging_level,
        args.ui_path,
    )


if __name__ == "__main__":
    main()