# file_name.py (5.57 KB) -- committed by Stelios Karozis.
import os

from jedi._compatibility import FileNotFoundError, force_unicode, scandir
from jedi.api import classes
from jedi.api.strings import StringName, get_quote_ending
from jedi.api.helpers import match
from jedi.inference.helpers import get_str_or_none


class PathName(StringName):
    # Name type wrapped into the completions that complete_file_name()
    # yields for directory entries; it reports the API type 'path'.
    api_type = u'path'


def complete_file_name(inference_state, module_context, start_leaf, quote, string,
                       like_name, signatures_callback, code_lines, position, fuzzy):
    """
    Yield ``classes.Completion`` objects for the directory entries that
    continue the path typed in the string literal at ``start_leaf``.

    The directory that gets listed is assembled from

    - string operands concatenated with ``+`` in front of ``start_leaf``,
    - the preceding arguments, if every signature found at ``position`` is
      ``os.path.join``,
    - the directory part of ``string`` (a leading ``~`` is expanded),

    and is resolved against ``inference_state.project._path``.  Nothing is
    yielded when the preceding operands cannot be resolved or when the
    directory cannot be listed.
    """
    # Only the last path component of the typed text can be replaced.
    like_name_length = len(os.path.basename(string))

    prefix = _get_string_additions(module_context, start_leaf)
    if string.startswith('~'):
        string = os.path.expanduser(string)
    if prefix is None:
        return

    # Split only after prepending the operands: for `'foo' + 'bar` the
    # entries have to start with `foobar`, not just with `bar`.
    directory, must_start_with = os.path.split(prefix + string)

    sigs = signatures_callback(*position)
    join_prefix = None
    if sigs and all(sig.full_name == 'os.path.join' for sig in sigs):
        join_prefix = _add_os_path_join(module_context, start_leaf, sigs[0].bracket_start)
    # A result of None means the tree did not confirm the join call.
    is_in_os_path_join = join_prefix is not None
    if is_in_os_path_join:
        directory = join_prefix + directory

    base_path = os.path.join(inference_state.project._path, directory)
    try:
        entries = list(scandir(base_path))
        entries.sort(key=lambda entry: entry.name)
    except (FileNotFoundError, OSError):
        # E.g. OSError: [Errno 36] File name too long: '...'
        return

    quote_ending = get_quote_ending(quote, code_lines, position)
    # Leading characters of a matched name that lie in front of the part
    # being replaced (for instance contributed by the `+` operands).
    skip = len(must_start_with) - like_name_length
    for entry in entries:
        if not match(entry.name, must_start_with, fuzzy=fuzzy):
            continue

        # Directories get a separator so that typing can continue, except
        # inside os.path.join, where the string is closed like for files.
        if is_in_os_path_join or not entry.is_dir():
            suffix = quote_ending
        else:
            suffix = os.path.sep

        yield classes.Completion(
            inference_state,
            PathName(inference_state, (entry.name + suffix)[skip:]),
            stack=None,
            like_name_length=like_name_length,
            is_fuzzy=fuzzy,
        )


def _get_string_additions(module_context, start_leaf):
    """
    Return the text of the operands that are concatenated with ``+``
    directly in front of ``start_leaf``.

    The result is ``''`` if the leaf before ``start_leaf`` is not a ``+``.
    Otherwise it is whatever ``_add_strings`` returns for the collected
    operands: the concatenated text, or ``None`` if an operand cannot be
    resolved to a single string.
    """
    plus_leaf = start_leaf.get_previous_leaf()
    if plus_leaf != '+':
        return ''

    context = module_context.create_context(start_leaf)

    # Walk leftwards over the siblings in front of the `+`.  They are
    # expected to alternate operand, `+`, operand, ...; the walk ends at the
    # first node that sits where a `+` should be but is something else.
    siblings = plus_leaf.parent.children
    operands = []
    expect_operand = True
    for node in reversed(siblings[:siblings.index(plus_leaf)]):
        if expect_operand:
            operands.append(node)
        elif node != '+':
            break
        expect_operand = not expect_operand

    # Operands were collected right-to-left; restore source order.
    operands.reverse()
    return _add_strings(context, operands)


def _add_strings(context, nodes, add_slash=False):
    """
    Infer every node in ``nodes`` and concatenate the resulting strings.

    With ``add_slash`` the pieces are separated by ``os.path.sep``.  Returns
    ``None`` as soon as a node does not infer to exactly one value or that
    value is not a string according to ``get_str_or_none``.
    """
    pieces = []
    for node in nodes:
        inferred = context.infer_node(node)
        if len(inferred) != 1:
            # Ambiguous or unknown value, no reliable path can be built.
            return None
        (value,) = inferred
        text = get_str_or_none(value)
        if text is None:
            return None
        pieces.append(force_unicode(text))

    separator = os.path.sep if add_slash else ''
    return separator.join(pieces)


def _add_os_path_join(module_context, start_leaf, bracket_start):
    """
    Return the path formed by the ``os.path.join`` arguments that precede
    the string at ``start_leaf``.

    :param module_context: Used to create the context in which the preceding
        argument nodes are inferred.
    :param start_leaf: Leaf of the string that is being completed.
    :param bracket_start: Start position of the call's opening bracket; the
        bracket located in the tree has to start exactly there.
    :return: The preceding arguments joined with ``os.path.sep`` (``''`` if
        there are none or if they cannot be resolved to strings), or
        ``None`` if no bracket starting at ``bracket_start`` is found in
        front of ``start_leaf``.
    """
    def check(maybe_bracket, nodes):
        # Only accept the candidate if it is the bracket of the signature.
        if maybe_bracket.start_pos != bracket_start:
            return None

        if not nodes:
            return ''
        context = module_context.create_context(nodes[0])
        # Unresolvable arguments are treated like no preceding arguments.
        return _add_strings(context, nodes, add_slash=True) or ''

    if start_leaf.type == 'error_leaf':
        # Unfinished string literal, like `join('`
        value_node = start_leaf.parent
        index = value_node.children.index(start_leaf)
        if index > 0:
            error_node = value_node.children[index - 1]
            if error_node.type == 'error_node' and len(error_node.children) >= 2:
                index = -2
                if error_node.children[-1].type == 'arglist':
                    # The bracket sits right in front of the arglist.
                    arglist_nodes = error_node.children[-1].children
                    index -= 1
                else:
                    # The bracket is the last child, no arguments so far.
                    arglist_nodes = []

                # Every second arglist child is skipped (presumably the
                # commas between the arguments).
                return check(error_node.children[index + 1], arglist_nodes[::2])
        return None

    # Maybe an arglist or some weird error case. Therefore checked below.
    searched_node_child = start_leaf
    while searched_node_child.parent is not None \
            and searched_node_child.parent.type not in ('arglist', 'trailer', 'error_node'):
        searched_node_child = searched_node_child.parent

    # The string has to be the very beginning of its argument.
    if searched_node_child.get_first_leaf() is not start_leaf:
        return None
    searched_node = searched_node_child.parent
    if searched_node is None:
        return None

    index = searched_node.children.index(searched_node_child)
    arglist_nodes = searched_node.children[:index]
    if searched_node.type == 'arglist':
        trailer = searched_node.parent
        if trailer.type == 'error_node':
            trailer_index = trailer.children.index(searched_node)
            assert trailer_index >= 2
            assert trailer.children[trailer_index - 1] == '('
            return check(trailer.children[trailer_index - 1], arglist_nodes[::2])
        elif trailer.type == 'trailer':
            return check(trailer.children[0], arglist_nodes[::2])
    elif searched_node.type == 'trailer':
        return check(searched_node.children[0], [])
    elif searched_node.type == 'error_node':
        # Stuff like `join(""`
        if not arglist_nodes:
            # The string is the first child of the error node, so there is
            # no bracket in front of it; indexing [-1] would raise
            # IndexError.
            return None
        return check(arglist_nodes[-1], [])
    return None