Source code for sparseml.utils.restricted_eval

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Restricted eval function for safely evaluating equations in recipes
"""


import ast
import operator
from typing import Any, Dict, Optional


__all__ = [
    "restricted_eval",
    "UnknownVariableException",
]


class UnknownVariableException(Exception):
    """
    Exception raised for unknown variable names in restricted eval

    :param var_name: name of unknown variable
    """

    def __init__(self, var_name: str):
        # keep the offending name on the instance so callers can inspect it
        # without having to parse the message
        self.var_name = var_name
        super().__init__(f"Unknown variable name in eval: {var_name}")
def restricted_eval(
    expression: str,
    variables: Optional[Dict[str, float]] = None,
) -> float:
    """
    Evaluate a single arithmetic expression without handing it to ``eval``.

    The expression is parsed with ``ast`` and only a whitelist of operators,
    functions, numeric literals, and the given variables are evaluated.

    :param expression: expression to evaluate
    :param variables: dictionary of string variables to float values that
        may be included in the expression
    :return: evaluated expression. Only supported operations, numbers, and
        float variables named in the variables dict may be included
    :raises: RuntimeError if the expression is empty, holds more than one
        statement, or includes any unsupported operations,
        UnknownVariableException if any variables not included in the
        variables dict are given. Text that is not valid Python syntax is
        reported by ``ast.parse`` as a SyntaxError
    """
    variables = variables or {}
    statements = ast.parse(expression.strip()).body

    # Indexing body[0] blindly raised IndexError for empty input and silently
    # ignored everything after the first statement; reject both explicitly.
    if len(statements) != 1:
        raise RuntimeError(
            "Expected exactly one expression to evaluate, "
            f"found {len(statements)} in {expression!r}"
        )

    return _restricted_eval_node(statements[0], variables)
# Whitelist of binary operator AST node types and the function that applies
# each one.
# NOTE(review): ast.Pow is evaluated without any bound on its operands; if
# expressions can come from untrusted input, very large exponents may be
# expensive to compute -- confirm whether that matters for callers.
_VALID_BINOPS_TO_EVAL = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Pow: operator.pow,
    ast.Mod: operator.mod,
}

# Whitelist of unary operator AST node types.
_VALID_UOPS_TO_EVAL = {
    ast.USub: operator.neg,
    ast.UAdd: operator.pos,
}

# Whitelist of callable names that may appear in an expression.
_VALID_FUNCTIONS_TO_EVAL = {
    "abs": abs,
    "float": float,
    "int": int,
    "min": min,
    "max": max,
    "round": round,
}


def _restricted_eval_node(node: Any, variables: Dict[str, float]) -> float:
    """
    Recursively evaluate one AST node using only whitelisted constructs.

    :param node: AST node to evaluate (expression statement, numeric
        constant, name, binary/unary operation, or function call)
    :param variables: dictionary of variable names to the values that are
        substituted for matching ``ast.Name`` nodes
    :return: the value the node evaluates to
    :raises: RuntimeError if the node, operator, function, or constant type
        is not supported, UnknownVariableException if a name is not a key
        of the variables dict
    """
    if isinstance(node, ast.Expr):
        # expression statement wrapper -- evaluate what it wraps
        return _restricted_eval_node(node.value, variables)

    # Numeric literal. bool is a subclass of int, so it is excluded
    # explicitly; strings, None, etc. fall through to the generic
    # "unsupported node" error at the bottom.
    if (
        isinstance(node, ast.Constant)
        and isinstance(node.value, (int, float, complex))
        and not isinstance(node.value, bool)
    ):
        return node.value

    # Parsers older than Python 3.8 emit ast.Num for numeric literals.
    # Matched by class name so the deprecated ast.Num attribute is never
    # touched on newer interpreters.
    if type(node).__name__ == "Num":
        return node.n

    if isinstance(node, ast.Name):
        if node.id in variables:
            return variables[node.id]
        raise UnknownVariableException(node.id)

    if isinstance(node, ast.BinOp):
        op_type = type(node.op)
        if op_type not in _VALID_BINOPS_TO_EVAL:
            raise RuntimeError(f"Unsupported binary operator type {op_type}")
        return _VALID_BINOPS_TO_EVAL[op_type](
            _restricted_eval_node(node.left, variables),
            _restricted_eval_node(node.right, variables),
        )

    if isinstance(node, ast.UnaryOp):
        op_type = type(node.op)
        if op_type not in _VALID_UOPS_TO_EVAL:
            raise RuntimeError(f"Unsupported unary operator type {op_type}")
        # a UnaryOp node stores its single argument in ``operand``
        return _VALID_UOPS_TO_EVAL[op_type](
            _restricted_eval_node(node.operand, variables),
        )

    if isinstance(node, ast.Call):
        # only plain names can be called; attribute access such as
        # ``math.sqrt(...)`` or calling the result of another call is
        # rejected
        if not isinstance(node.func, ast.Name):
            raise RuntimeError(
                f"Unsupported function call target {type(node.func)}"
            )
        func_name = node.func.id
        if func_name not in _VALID_FUNCTIONS_TO_EVAL:
            raise RuntimeError(f"Unsupported function name {func_name}")

        args = [_restricted_eval_node(arg, variables) for arg in node.args]
        kwargs = {}
        for kwarg in node.keywords:
            # ``f(**mapping)`` is represented as a keyword whose arg is None
            if kwarg.arg is None:
                raise RuntimeError(
                    "Unsupported ** keyword unpacking in function call"
                )
            kwargs[kwarg.arg] = _restricted_eval_node(kwarg.value, variables)
        return _VALID_FUNCTIONS_TO_EVAL[func_name](*args, **kwargs)

    raise RuntimeError(f"Unsupported AST node type {type(node)}")