# test_utils.py (2.26 KB) -- last committed by Stelios Karozis
# Copyright (c) 2013-2014 Google, Inc.
# Copyright (c) 2014 LOGILAB S.A. (Paris, FRANCE) <contact@logilab.fr>
# Copyright (c) 2015-2016, 2018 Claudiu Popa <pcmanticore@gmail.com>
# Copyright (c) 2015-2016 Ceridwen <ceridwenv@gmail.com>
# Copyright (c) 2016 Jakub Wilk <jwilk@jwilk.net>
# Copyright (c) 2018 Anthony Sottile <asottile@umich.edu>

# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER

"""Utility functions for test code that uses astroid ASTs as input."""
import contextlib
import functools
import sys
import warnings

import pytest

from astroid import nodes


def require_version(minver=None, maxver=None):
    """Skip the decorated test unless the running Python version is in range.

    The decorated function is returned untouched when
    ``minver < sys.version_info[:3] <= maxver`` holds (the lower bound is
    exclusive, the upper bound inclusive). Otherwise it is replaced by a
    wrapper that calls ``pytest.skip`` with a message naming the bound that
    was actually violated.

    Bounds are compared as integer tuples, so a shorter tuple sorts before a
    longer one sharing its prefix: ``maxver="3.6"`` rejects ``3.6.0`` and
    later, while ``minver="3.6"`` accepts ``3.6.0``.

    :param minver: Exclusive lower bound as an ``"X.Y[.Z]"`` string.
        A falsy value stands for ``"0"``.
    :param maxver: Inclusive upper bound as an ``"X.Y[.Z]"`` string.
        A falsy value stands for ``"4"``.
    :returns: A decorator to apply to a test function.
    :raises ValueError: At decoration time, if a bound is not a dot-separated
        sequence of integers.
    """

    def parse(string, default=None):
        # Turn "X.Y[.Z]" into a tuple of ints, falling back to ``default``
        # when no bound was supplied.
        string = string or default
        try:
            return tuple(int(v) for v in string.split("."))
        except ValueError as exc:
            raise ValueError(
                "{string} is not a correct version : should be X.Y[.Z].".format(
                    string=string
                )
            ) from exc

    def check_require_version(f):
        current = sys.version_info[:3]
        # Parse both bounds up front so a malformed one is always reported,
        # whichever side of the comparison would have short-circuited.
        lower = parse(minver, "0")
        upper = parse(maxver, "4")
        if lower < current <= upper:
            return f

        str_version = ".".join(str(v) for v in sys.version_info)

        # Build the reason from the bound that failed, not from which
        # argument happens to be set: with both bounds given and an
        # interpreter that is too new, the upper bound must be reported.
        if current <= lower:
            reason = "Needs Python > %s. Current version is %s." % (
                minver or "0",
                str_version,
            )
        else:
            reason = "Needs Python <= %s. Current version is %s." % (
                maxver or "4",
                str_version,
            )

        @functools.wraps(f)
        def new_f(*args, **kwargs):
            # Always skip: an out-of-range test must never silently pass.
            pytest.skip(reason)

        return new_f

    return check_require_version


def get_name_node(start_from, name, index=0):
    """Return one of the ``nodes.Name`` nodes called ``name``.

    Candidates are taken from ``start_from.nodes_of_class(nodes.Name)`` in
    the order that call yields them; those whose ``name`` attribute equals
    ``name`` are kept and the one at position ``index`` is returned.

    :raises IndexError: If ``index`` is out of range for the matches found.
    """
    matches = []
    for candidate in start_from.nodes_of_class(nodes.Name):
        if candidate.name == name:
            matches.append(candidate)
    return matches[index]


@contextlib.contextmanager
def enable_warning(warning):
    """Report every occurrence of ``warning`` inside the ``with`` block.

    An ``"always"`` filter for the given warning category is installed for
    the duration of the block. On exit the warning filters are put back
    exactly as they were on entry, so filters coming from the ``-W`` flag or
    from the test runner apply again.

    :param warning: The warning category (a ``Warning`` subclass) to enable.
    """
    # catch_warnings() snapshots the filter state and restores it on exit,
    # even if the body raises. Installing a "default" filter afterwards
    # instead would sit in front of, and so override, any user-provided
    # filter for this category.
    with warnings.catch_warnings():
        warnings.simplefilter("always", warning)
        yield