import _ast
import operator
class SafeEvalError(Exception):
    """Base exception for errors raised by the safe-eval helpers in this module."""
class UnsafeCode(SafeEvalError):
    """Raised when the input uses syntax outside the supported literal subset."""
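# Types safe_eval is meant to return:
#   Containers: list, tuple, dict, set (frozenset has no literal syntax,
#               so it cannot be written in the input)
#   Scalars:    str, unicode, int, long, complex, float
#               (unicode and long exist only on Python 2)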
def safe_eval(text):
    """Evaluate *text* like ``eval``, but accept only literal expressions.

    Supported syntax: number and string literals; list, tuple, dict and set
    displays; the unary operators ``- + ~ not``; the arithmetic and bitwise
    binary operators; and ``and``/``or``. Names, calls, attribute access,
    subscripts and every statement form are rejected.

    :param text: source code containing exactly one expression.
    :returns: the value the expression evaluates to.
    :raises UnsafeCode: if *text* is not a single expression, or uses any
        unsupported syntax.
    :raises SyntaxError: (from ``compile``) if *text* does not parse.
    """
    tree = compile(text, "<string>", "exec", _ast.PyCF_ONLY_AST)
    # ``exec`` mode accepts arbitrary statements. Without this check an
    # assignment such as ``x = 1`` (which also has ``.value``) would be
    # evaluated, an empty input would raise IndexError, and any statements
    # after the first would be silently ignored.
    if len(tree.body) != 1 or not isinstance(tree.body[0], _ast.Expr):
        raise UnsafeCode("expected a single expression")
    return _traverse(tree.body[0].value)


# Literal node classes differ between Python versions. Python 3.8+ parses
# every literal into ``Constant``, and ``_ast`` no longer defines ``Str`` or
# ``Num``. Python 2 has only ``Str``/``Num``. A class that is missing on the
# running interpreter becomes ``()``, which ``isinstance`` never matches.
_CONSTANT_NODE = getattr(_ast, "Constant", ())
_STR_NODE = getattr(_ast, "Str", ())
_NUM_NODE = getattr(_ast, "Num", ())

# NOTE(review): operators are applied to input-controlled operands with no
# size limits. Inputs such as ``9**9**9``, ``1 << 10**10`` or
# ``'a' * 10**10`` can exhaust CPU or memory, and very deep nesting can hit
# the recursion limit. Add bounds if *text* may come from untrusted sources.
_BINARY_OPS = {
    _ast.Add: operator.add,
    _ast.Sub: operator.sub,
    _ast.Mult: operator.mul,
    # ``operator.div`` exists only on Python 2, where it matches this module's
    # classic ``/`` (there is no ``from __future__ import division``).
    # Python 3's ``/`` is true division.
    _ast.Div: getattr(operator, "div", operator.truediv),
    _ast.FloorDiv: operator.floordiv,
    _ast.Mod: operator.mod,
    _ast.Pow: operator.pow,
    _ast.LShift: operator.lshift,
    _ast.RShift: operator.rshift,
    _ast.BitAnd: operator.and_,
    _ast.BitOr: operator.or_,
    _ast.BitXor: operator.xor,
}

_UNARY_OPS = {
    _ast.Invert: operator.invert,   # ``~x``
    _ast.USub: operator.neg,        # ``-x``
    _ast.UAdd: operator.pos,        # ``+x``
    _ast.Not: operator.not_,        # ``not x``
}


def _unsupported(node):
    """Build the UnsafeCode error reported for a rejected AST node."""
    return UnsafeCode("unsupported syntax: %s" % type(node).__name__)


def _traverse(ast):
    """Recursively evaluate one literal-only expression node.

    :param ast: an ``_ast`` expression node. The name shadows the ``ast``
        module, but it is kept so that existing keyword callers still work.
    :returns: the Python value the node denotes.
    :raises UnsafeCode: for any node type or operator not handled below.
    """
    if isinstance(ast, _ast.List):
        return [_traverse(el) for el in ast.elts]
    if isinstance(ast, _ast.Tuple):
        return tuple(_traverse(el) for el in ast.elts)
    if isinstance(ast, _ast.Set):
        return set(_traverse(el) for el in ast.elts)
    if isinstance(ast, _ast.Dict):
        return dict(
            (_traverse(k), _traverse(v)) for k, v in zip(ast.keys, ast.values)
        )
    if isinstance(ast, _CONSTANT_NODE):
        # Python 3.8+ uses Constant for every literal, so this also admits
        # bytes, True/False/None and Ellipsis. All of these are immutable.
        return ast.value
    if isinstance(ast, _STR_NODE):
        return ast.s
    if isinstance(ast, _NUM_NODE):
        return ast.n
    if isinstance(ast, _ast.Expr):
        return _traverse(ast.value)
    if isinstance(ast, _ast.BinOp):
        func = _BINARY_OPS.get(type(ast.op))
        if func is None:
            raise _unsupported(ast.op)
        return func(_traverse(ast.left), _traverse(ast.right))
    if isinstance(ast, _ast.UnaryOp):
        func = _UNARY_OPS.get(type(ast.op))
        if func is None:
            raise _unsupported(ast.op)
        return func(_traverse(ast.operand))
    if isinstance(ast, _ast.BoolOp):
        return _eval_boolop(ast)
    # Anything else: names, calls, attributes, subscripts, lambdas, ...
    raise _unsupported(ast)


def _eval_boolop(node):
    """Evaluate an ``and``/``or`` node with Python's short-circuit semantics.

    Like the real operators, this returns the deciding operand itself
    (``1 and 2`` -> ``2``, ``0 or 'x'`` -> ``'x'``) rather than a bool.
    Operands after the deciding one are not evaluated.
    """
    if isinstance(node.op, _ast.And):
        stop_when = False   # ``and`` stops at the first falsy operand
    elif isinstance(node.op, _ast.Or):
        stop_when = True    # ``or`` stops at the first truthy operand
    else:
        raise _unsupported(node.op)
    value = None
    for operand in node.values:
        value = _traverse(operand)
        if bool(value) == stop_when:
            break
    return value
if __name__ == "__main__":
    # Demo: a top-level tuple mixing containers, complex literals and operators.
    # A single-argument print(...) behaves the same on Python 2 and Python 3;
    # the old print statement made this module a SyntaxError on Python 3.
    print(safe_eval("[1,2,3,{'hello':1}, (1,-2,3)], 4j, 1+5j, ~1+2*3"))