From aa18c52226ddf3d5a3448702fc2b1a14b0ee3c2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zbigniew=20J=C4=99drzejewski-Szmek?= Date: Wed, 16 Aug 2023 17:23:36 +0200 Subject: [PATCH 3/7] Revert "Add in protections against call to `eval(expression)`" This reverts commit 4b2d89cf14e75030d27629925b9998e1e91d23c7. --- numexpr/necompiler.py | 26 +++++++----------- numexpr/tests/test_numexpr.py | 50 ++++------------------------------- 2 files changed, 15 insertions(+), 61 deletions(-) diff --git a/numexpr/necompiler.py b/numexpr/necompiler.py index fef886baf5..37052acadb 100644 --- a/numexpr/necompiler.py +++ b/numexpr/necompiler.py @@ -13,7 +13,6 @@ import __future__ import sys import numpy import threading -import re is_cpu_amd_intel = False # DEPRECATION WARNING: WILL BE REMOVED IN FUTURE RELEASE from numexpr import interpreter, expressions, use_vml @@ -260,17 +259,10 @@ class Immediate(Register): def __str__(self): return 'Immediate(%d)' % (self.node.value,) -_forbidden_re = re.compile('[\;[\:]|__') + def stringToExpression(s, types, context): """Given a string, convert it to a tree of ExpressionNode's. """ - # sanitize the string for obvious attack vectors that NumExpr cannot - # parse into its homebrew AST. This is to protect the call to `eval` below. - # We forbid `;`, `:`. `[` and `__` - # We would like to forbid `.` but it is both a reference and decimal point. - if _forbidden_re.search(s) is not None: - raise ValueError(f'Expression {s} has forbidden control characters.') - old_ctx = expressions._context.get_current_context() try: expressions._context.set_new_context(context) @@ -293,10 +285,8 @@ def stringToExpression(s, types, context): t = types.get(name, default_type) names[name] = expressions.VariableNode(name, type_to_kind[t]) names.update(expressions.functions) - # now build the expression ex = eval(c, names) - if expressions.isConstant(ex): ex = expressions.ConstantNode(ex, expressions.getKind(ex)) elif not isinstance(ex, expressions.ExpressionNode): @@ -621,7 +611,9 @@ def NumExpr(ex, signature=(), **kwargs): Returns a `NumExpr` object containing the compiled function. """ - + # NumExpr can be called either directly by the end-user, in which case + # kwargs need to be sanitized by getContext, or by evaluate, + # in which case kwargs are in already sanitized. # In that case _frame_depth is wrong (it should be 2) but it doesn't matter # since it will not be used (because truediv='auto' has already been # translated to either True or False). @@ -766,7 +758,7 @@ def getArguments(names, local_dict=None, global_dict=None, _frame_depth: int=2): _names_cache = CacheDict(256) _numexpr_cache = CacheDict(256) _numexpr_last = {} -_numexpr_sanity = set() + evaluate_lock = threading.Lock() # MAYBE: decorate this function to add attributes instead of having the @@ -869,7 +861,7 @@ def evaluate(ex: str, out: numpy.ndarray = None, order: str = 'K', casting: str = 'safe', - _frame_depth: int = 3, + _frame_depth: int=3, **kwargs) -> numpy.ndarray: """ Evaluate a simple array expression element-wise using the virtual machine. @@ -917,8 +909,6 @@ def evaluate(ex: str, _frame_depth: int The calling frame depth. Unless you are a NumExpr developer you should not set this value. - - """ # We could avoid code duplication if we called validate and then re_evaluate # here, but they we have difficulties with the `sys.getframe(2)` call in @@ -931,6 +921,10 @@ def evaluate(ex: str, else: raise e + + + + def re_evaluate(local_dict: Optional[Dict] = None, _frame_depth: int=2) -> numpy.ndarray: """ diff --git a/numexpr/tests/test_numexpr.py b/numexpr/tests/test_numexpr.py index ebc41c8d54..ccb0b6cb07 100644 --- a/numexpr/tests/test_numexpr.py +++ b/numexpr/tests/test_numexpr.py @@ -373,9 +373,8 @@ class test_evaluate(TestCase): a1 = array([1., 2., 3.]) b1 = array([4., 5., 6.]) c1 = array([7., 8., 9.]) - local_dict={'a': a1, 'b': b1, 'c': c1} - x = evaluate("2*a + 3*b*c", local_dict=local_dict) - x = re_evaluate(local_dict=local_dict) + x = evaluate("2*a + 3*b*c", local_dict={'a': a1, 'b': b1, 'c': c1}) + x = re_evaluate() assert_array_equal(x, array([86., 124., 168.])) def test_validate(self): @@ -401,10 +400,9 @@ class test_evaluate(TestCase): a1 = array([1., 2., 3.]) b1 = array([4., 5., 6.]) c1 = array([7., 8., 9.]) - local_dict={'a': a1, 'b': b1, 'c': c1} - retval = validate("2*a + 3*b*c", local_dict=local_dict) + retval = validate("2*a + 3*b*c", local_dict={'a': a1, 'b': b1, 'c': c1}) assert(retval is None) - x = re_evaluate(local_dict=local_dict) + x = re_evaluate() assert_array_equal(x, array([86., 124., 168.])) # Test for issue #22 @@ -504,49 +502,11 @@ class test_evaluate(TestCase): a = arange(3) try: evaluate("a < [0, 0, 0]") - except (ValueError, TypeError): + except TypeError: pass else: self.fail() - def test_forbidden_tokens(self): - # Forbid dunder - try: - evaluate('__builtins__') - except ValueError: - pass - else: - self.fail() - - # Forbid colon for lambda funcs - try: - evaluate('lambda x: x') - except ValueError: - pass - else: - self.fail() - - # Forbid indexing - try: - evaluate('locals()[]') - except ValueError: - pass - else: - self.fail() - - # Forbid semicolon - try: - evaluate('import os; os.cpu_count()') - except ValueError: - pass - else: - self.fail() - - # I struggle to come up with cases for our ban on `'` and `"` - - - - def test_disassemble(self): assert_equal(disassemble(NumExpr( "where(m, a, -1)", [('m', bool), ('a', float)])),