"""
Code for describing layers / operators in ML framework neural networks.
"""

import json
from typing import Any, Dict, List, Tuple

from sparseml.utils import clean_path, create_parent_dirs

__all__ = ["AnalyzedLayerDesc"]

class AnalyzedLayerDesc(object):
    """
    Description of an executed neural network layer.
    Contains information about the number of flops, shapes, params, etc.

    :param name: name of the layer
    :param type_: type of the layer
    :param params: number of parameters of the layer
    :param zeroed_params: number of parameters with values of zero
    :param prunable_params: number of parameters that could be pruned
    :param params_dims: dimensions of parameters
    :param prunable_params_dims: dimensions of prunable parameters
    :param execution_order: execution order of the layer/operation
    :param input_shape: shapes of input tensors
    :param output_shape: shapes of output tensors
    :param flops: Unused
    :param total_flops: total number of float operations
    :param stride: stride value for the layer, if any; stored and
        serialized as given
    """

    @staticmethod
    def save_descs(descs: List, path: str):
        """
        Save a list of AnalyzedLayerDesc to a json file

        :param descs: a list of descriptions to save
        :param path: the path to save the descriptions at
        """
        path = clean_path(path)
        create_parent_dirs(path)
        save_obj = {"descriptions": [desc.dict() for desc in descs]}

        with open(path, "w") as file:
            json.dump(save_obj, file)

    @staticmethod
    def load_descs(path: str) -> List:
        """
        Load a list of AnalyzedLayerDesc from a json file

        :param path: the path to load the descriptions from
        :return: the loaded list of AnalyzedLayerDesc
        """
        path = clean_path(path)

        with open(path, "r") as file:
            obj = json.load(file)

        descs = []

        for desc_obj in obj["descriptions"]:
            # work on a copy so the parsed json structure is left untouched
            kwargs = dict(desc_obj)
            # dict() serializes the ``type_`` attribute under the key "type"
            kwargs["type_"] = kwargs.pop("type")
            # "terminal" and "prunable" are derived properties, not
            # constructor arguments; tolerate files that do not contain them
            kwargs.pop("terminal", None)
            kwargs.pop("prunable", None)
            descs.append(AnalyzedLayerDesc(**kwargs))

        return descs

    @staticmethod
    def merge_descs(orig, descs: List):
        """
        Merge a layer description with a list of others.
        The result copies every field from ``orig`` and then adds the
        flops, total_flops, params, prunable_params and zeroed_params
        counts of each description in ``descs``. ``orig`` is not modified.

        :param orig: original description
        :param descs: list of descriptions to merge with
        :return: a combined description
        """
        merged = AnalyzedLayerDesc(
            name=orig.name,
            type_=orig.type_,
            params=orig.params,
            zeroed_params=orig.zeroed_params,
            prunable_params=orig.prunable_params,
            params_dims=orig.params_dims,
            prunable_params_dims=orig.prunable_params_dims,
            execution_order=orig.execution_order,
            input_shape=orig.input_shape,
            output_shape=orig.output_shape,
            flops=orig.flops,
            total_flops=orig.total_flops,
            stride=orig.stride,
        )

        for desc in descs:
            merged.flops += desc.flops
            merged.total_flops += desc.total_flops
            merged.params += desc.params
            merged.prunable_params += desc.prunable_params
            merged.zeroed_params += desc.zeroed_params

        return merged

    def __init__(
        self,
        name: str,
        type_: str,
        params: int = 0,
        zeroed_params: int = 0,
        prunable_params: int = 0,
        params_dims: Dict[str, Tuple[int, ...]] = None,
        prunable_params_dims: Dict[str, Tuple[int, ...]] = None,
        execution_order: int = -1,
        input_shape: Tuple[Tuple[int, ...], ...] = None,
        output_shape: Tuple[Tuple[int, ...], ...] = None,
        flops: int = 0,
        total_flops: int = 0,
        stride: Tuple[int, ...] = None,
    ):
        # all arguments are stored unchanged as same-named attributes
        self.name = name
        self.type_ = type_
        self.params = params
        self.prunable_params = prunable_params
        self.zeroed_params = zeroed_params
        self.params_dims = params_dims
        self.prunable_params_dims = prunable_params_dims
        self.execution_order = execution_order
        self.input_shape = input_shape
        self.output_shape = output_shape
        self.flops = flops
        self.total_flops = total_flops
        self.stride = stride

    def __repr__(self):
        return "AnalyzedLayerDesc({})".format(self.dict())

    @property
    def terminal(self) -> bool:
        """
        :return: True if this is a terminal op, ie it is doing compute and is
            not a container, False otherwise
        """
        return self.params_dims is not None

    @property
    def prunable(self) -> bool:
        """
        :return: True if the layer supports kernel sparsity (is prunable),
            False otherwise
        """
        return self.prunable_params > 0

    def dict(self) -> Dict[str, Any]:
        """
        :return: A serializable dictionary representation of the current
            instance
        """
        return {
            "name": self.name,
            "type": self.type_,
            "params": self.params,
            "zeroed_params": self.zeroed_params,
            "prunable_params": self.prunable_params,
            "params_dims": self.params_dims,
            "prunable_params_dims": self.prunable_params_dims,
            "execution_order": self.execution_order,
            "input_shape": self.input_shape,
            "output_shape": self.output_shape,
            "stride": self.stride,
            "flops": self.flops,
            "total_flops": self.total_flops,
            "terminal": self.terminal,
            "prunable": self.prunable,
        }