Source code for sparsify.models.base

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Base DB model classes for the server
"""

import datetime
import logging

from peewee import DateTimeField, Model, TextField
from playhouse.sqlite_ext import JSONField
from playhouse.sqliteq import SqliteQueueDatabase


__all__ = [
    "FileStorage",
    "database",
    "storage",
    "BaseModel",
    "BaseCreatedModifiedModel",
    "ListObjField",
    "CSVField",
    "CSVIntField",
    "CSVFloatField",
]

_LOGGER = logging.getLogger(__name__)


class FileStorage:
    """
    Tracks the local directory under which large artifacts are stored.
    Used for files that are a poor fit for the DB, such as model and data files.
    The root directory must be supplied through ``init`` before ``root_path``
    can be read.
    """

    def __init__(self):
        # no location yet; reading root_path raises until init() is called
        self._root_path = None

    @property
    def root_path(self) -> str:
        """
        :return: the root path on the local file system for where to store files
        :raises ValueError: if init() has not been called yet
        """
        self._validate_setup()
        return self._root_path

    def init(self, root_path: str):
        """
        Configure the directory used for file storage.

        :param root_path: the root path on the local file system for where to
            store files
        """
        self._root_path = root_path

    def _validate_setup(self):
        """
        Guard used by accessors that need a configured root path.

        :raises ValueError: if no root path has been set via init()
        """
        if self._root_path is not None:
            return
        raise ValueError("root_path is not set, call init first")
# Shared queue-backed SQLite database for all models. The database name is
# None here, so presumably the real path is supplied later via
# ``database.init(...)``; confirm at the call site.
# NOTE(review): autostart=False suggests the queue's writer must be started
# explicitly elsewhere (e.g. ``database.start()``). Confirm.
database = SqliteQueueDatabase(
    None,
    autostart=False,
    use_gevent=False,
    results_timeout=30,
    queue_max_size=128,
)

# Shared local file storage; its root path is set later via ``storage.init``.
storage = FileStorage()
class BaseModel(Model):
    """
    Common peewee base model that all DB models must extend from.
    It binds subclasses to the shared ``database`` and ``storage`` objects.
    """

    class Meta:
        database = database
        # non-standard Meta option carrying the shared FileStorage;
        # presumably read elsewhere via ``_meta.storage``
        storage = storage

    def refresh(self):
        """
        Load this model's row again from the DB.

        Note that ``self`` is NOT updated in place. A new instance of the same
        class is fetched with ``get_by_id`` using this instance's primary key.

        :return: a freshly loaded instance of this model's class
        """
        model_cls = type(self)
        return model_cls.get_by_id(self._pk)
class BaseCreatedModifiedModel(BaseModel):
    """
    Base peewee model that records creation and last-modification timestamps.
    """

    # both default to the current local (naive) time when a row is created
    created = DateTimeField(default=datetime.datetime.now)
    modified = DateTimeField(default=datetime.datetime.now)

    def save(self, *args, **kwargs):
        """
        Stamp ``modified`` with the current local time, then delegate to the
        parent ``save``.

        NOTE(review): if a caller limits the save with ``only=...`` and omits
        ``modified``, the new timestamp is set on the instance but presumably
        not written. Confirm against peewee's ``Model.save`` semantics.

        :return: whatever the parent ``save`` returns
        """
        now = datetime.datetime.now()
        self.modified = now
        return super().save(*args, **kwargs)
class ListObjField(JSONField):
    """
    JSON-backed peewee field for lists of objects.

    Non-empty lists are stored wrapped as ``{"list": [...]}``. Falsy values are
    handed to the JSON field unchanged. Reads unwrap the ``"list"`` key and
    yield ``[]`` for any falsy decoded value.
    """

    def db_value(self, value):
        wrapped = {"list": value} if value else value
        return super().db_value(wrapped)

    def python_value(self, value):
        decoded = super().python_value(value)
        if not decoded:
            return []
        return decoded["list"]
class CSVField(TextField):
    """
    CSV field for handling lists of strings in a peewee database.

    Lists are stored as a single comma-joined TEXT value. An empty list is
    stored as ``""`` and read back as ``[]``. ``None`` maps to ``None`` in
    both directions.
    """

    def db_value(self, value):
        """
        Convert a Python value into the TEXT stored in the DB.

        :param value: an iterable of items (each converted with ``str``), an
            already-encoded CSV string, or None
        :return: the comma-joined string, or None when value is None
        """
        # None passes through as SQL NULL. A str is assumed to be already
        # encoded; iterating it would split it into characters ("a,b" -> "a,,,b").
        if value is None or isinstance(value, str):
            return value

        # Empty iterables join to "", which python_value maps back to [].
        # Previously [] was returned as-is, and SQLite cannot bind a list.
        return ",".join(str(item) for item in value)

    def python_value(self, value):
        """
        Convert the stored TEXT back into a list of strings.

        :param value: the raw DB value (str or None)
        :return: None for NULL, [] for an empty string, else the split list
        """
        if value is None:
            return None
        return value.split(",") if value else []
class CSVIntField(CSVField):
    """
    CSV field for handling lists of integers in a peewee database.

    Unlike the base CSVField, a NULL or empty stored value is read back as
    ``[]`` rather than None.
    """

    def python_value(self, value):
        if not value:
            return []
        return list(map(int, value.split(",")))
class CSVFloatField(CSVField):
    """
    CSV field for handling lists of floats in a peewee database.

    Unlike the base CSVField, a NULL or empty stored value is read back as
    ``[]`` rather than None.
    """

    def python_value(self, value):
        if not value:
            return []
        return list(map(float, value.split(",")))