nvim-config

Log | Files | Refs | Submodules | README

commit cd070614ce2c018cdae8bdd9a56b1824fc1a7cb9
parent 3433819c69837aafd9661414e57fc088cf4766dd
Author: Thomas Vigouroux <thomas.vigouroux@protonmail.com>
Date:   Sun,  3 Apr 2022 15:11:31 +0200

feat: add static lua stack analyzer

Diffstat:
Alua/lua-stack-check.lua | 560+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 560 insertions(+), 0 deletions(-)

diff --git a/lua/lua-stack-check.lua b/lua/lua-stack-check.lua @@ -0,0 +1,560 @@ +local M = {} + +local api = vim.api +local ts = vim.treesitter +local tsq = require'vim.treesitter.query' + +local QUERY = [[ +;; Upvalues are defined before the function definition +((comment) @up_comment + . + (_ (function_declarator declarator: (identifier) @fname) + body: (_)) @body + (#lua-match? @up_comment "^//s ")) + +;; Stack arguments are defined in the first comment of the function body +(function_definition + declarator: (_ declarator: (identifier) @fname) + body: (compound_statement . (comment) @before_comment) + (#lua-match? @before_comment "^//s ")) @body + +;; Return stack state is defined in the last comment of the function body +(function_definition + declarator: (_ declarator: (identifier) @fname) + body: (compound_statement (comment) @after_comment . ) + (#lua-match? @after_comment "^//s ")) @body +]] +local query = tsq.parse_query('c', QUERY) + +local function parse_stack(comment) + comment = comment:sub(5) + -- First collect all items + local items = {} + local init = 1 + while true do + local start, stop = string.find(comment, "[a-z_]+", init) + if start then + table.insert(items, comment:sub(start, stop)) + init = stop + 1 + else + break + end + end + return items +end + +------------------ +-- Type +------------------ + +local Type = { + UNKNOWN = 0, + STRING = 1, + INTEGER = 2, + BOOLEAN = 3, + TABLE = 4, + NIL = 5, + USERDATA = 6, +} +vim.tbl_add_reverse_lookup(Type) + +M.Type = Type + +------------------ +-- Element +------------------ + +local Element = {} +Element.__index = Element + +function Element.new(name, type) + local self = setmetatable({ + _name = name or "", + _type = type or Type.UNKNOWN + }, Element) + return self +end + +function Element.nil_() + return Element.new("nil", Type.NIL) +end + +function Element.anon(type) + return Element.new("", type) +end + +M.Element = Element + +------------------ +-- Stack +------------------ + +local Stack = {} +Stack.__index = Stack + +function Stack.new(items) + local inner = {} + if items then + for _, elem in pairs(parse_stack(items)) do + local type = Type.UNKNOWN + if elem == "nil" then + type = Type.NIL + end + table.insert(inner, Element.new(elem, type)) + end + end + local self = setmetatable({ + _inner = inner, + }, Stack) + return self +end + +function Stack:copy() + return setmetatable(vim.deepcopy(self), Stack) +end + +-- Stack manipulation + +function Stack:insert(...) + table.insert(self._inner, ...) +end + +function Stack:push(element) + self:insert(element) +end + +function Stack:push_all(other) + for _,elem in pairs(other._inner) do + self:push(elem) + end +end + +function Stack:remove(...) + return table.remove(self._inner,...) +end + +function Stack:pop() + return self:remove() +end + +function Stack:get(index) + if index < 0 then + return self._inner[#self._inner + index + 1] + else + return self._inner[index] + end +end + +function Stack:print() + local to_print = {} + for index,element in ipairs(self._inner) do + to_print[index] = string.format("%s: %s", element._name, Type[element._type]) + end + print(vim.inspect(to_print)) +end + +function Stack:match(other, strict) + local short, long + + local function report() + print("Stacks not matching...") + self:print() + other:print() + end + + if strict then + short = self._inner + long = other._inner + if #short ~= #long then + report() + return false + end + else + if #self._inner > #other._inner then + short = other._inner + long = self._inner + else + short = self._inner + long = other._inner + end + end + + for i=1,#short do + local left = short[i] + local right = long[#long - #short + i] + + local function sync_names() + local left_has_name = #(left._name) + local right_has_name = #(right._name) + if left_has_name and not right_has_name then + right._name = left._name + elseif right_has_name and not left_has_name then + left._name = right._name + end + end + + local types_unknown = left._type == Type.UNKNOWN or right._type == Type.UNKNOWN + local types_are_nil = left._type == Type.NIL or right._type == Type.NIL + if not types_unknown + and left._type ~= right._type + and not types_are_nil then + report() + return false + elseif types_unknown then + if left._type == Type.UNKNOWN then + left._type = right._type + else + right._type = left._type + end + end + sync_names() + end + return true +end + +function Stack:len() + return #self._inner +end + +M.Stack = Stack + +------------------ +-- FunctionTree +------------------ + +local FunctionTree = { } +FunctionTree.__index = FunctionTree + +local function lookup(q, match) + local lookup_match = {} + for id, node in pairs(match) do + lookup_match[q.captures[id]] = node + end + return lookup_match +end + +-- A function that returns a function that returns a function +-- Used to cache the arg_query in one place +local make_check_function = (function() + local arg_query = tsq.parse_query('c', [[ + (argument_list + (call_expression + function: (_) @fname + arguments: (argument_list (_) @num .) @call .)) + (argument_list (number_literal) @num .) + ]]) + + return function(type) + return function(call, caller, stack, bufnr) + local _,last_arg = arg_query:iter_matches(call, bufnr)() + last_arg = lookup(arg_query, last_arg) + if last_arg.call then + local fname = tsq.get_node_text(last_arg.fname, bufnr) + if fname == "lua_upvalueindex" then + if not caller.up then return false end + local index = tonumber(tsq.get_node_text(last_arg.num, bufnr)) + local element = caller.up:get(index) + if not element then return false end + + if element._type == Type.UNKNOWN then + element._type = type + return true + end + + return element._type == type + else + -- Can't analyze this call: assume it is correct + return true + end + else + local index = tonumber(tsq.get_node_text(last_arg.num, bufnr)) + local element = stack:get(index) + if not element then return false end + + return element._type == type + end + end + end +end)() + +local default_contracts = { + lua_pushboolean = { + post = function(_call, _caller, stack, _bufnr) + stack:push(Element.anon(Type.BOOLEAN)) + end + }, + lua_pushnil = { + post = function(_call, _caller, stack, _bufnr) + stack:push(Element.nil_()) + end + }, + lua_newuserdata = { + post = function(_call, _caller, stack, _bufnr) + stack:push(Element.anon(Type.USERDATA)) + end + }, + luaL_checkstring = { + pre = make_check_function(Type.STRING) + } +} + +local State = {} +State.__index = State +local state_count = 0 + +function State.new(name, func, bufnr, steps) + local stack + if func.before then + stack = func.before:copy() + else + stack = Stack.new() + end + + state_count = state_count + 1; + return setmetatable({ + _id = state_count, + _root = func, + _root_name = name, + _stack = stack, + _bufnr = bufnr, + _steps = steps, + _ignored = {}, + _node_index = 1 + }, State) +end + +function State:split() + local ignored = {} + for node,_ in pairs(self._ignored) do + ignored[node] = true + end + + state_count = state_count + 1; + return setmetatable({ + _id = state_count, + _root = self._root, + _root_name = self._root_name, + _stack = self._stack:copy(), + _bufnr = self._bufnr, + _steps = self._steps, + _ignored = ignored, + _node_index = self._node_index + }, State) +end + +function State:_is_ignored(node) + for _,dest in ipairs(self._ignored) do + local n_sr, n_sc, n_er, n_ec = node:range() + local d_sr, d_sc, d_er, d_ec = dest:range() + + local start_fits = n_sr > d_sr or (d_sr == n_sr and n_sc > d_sc) + local end_fits = n_er < d_er or (n_er == d_er and n_ec < d_ec) + if start_fits and end_fits then return true end + end + return false +end + +function State:report(...) + print(" ", self._id, ":", ...) +end + +function State:step(functree) + local match = self._steps[self._node_index] + self._node_index = self._node_index + 1 + + if not match then return {}, false end + if self:_is_ignored(match.node) then return { self }, false end + + if match.call then + -- This is a function call + local fname = tsq.get_node_text(match.fname, self._bufnr) + local contract = functree._contracts[fname] + if contract then + if contract.pre + and not contract.pre(match.call, self._root, self._stack, self._bufnr) then + local call_row = match.fname:range() + self:report(string.format("Precondition violation on row: %d (%s)", call_row, fname)) + return {}, true + end + + if contract.post then + contract.post(self.call, self._root, self._stack, self._bufnr) + end + else + self:report("Missing contract for", fname) + local to_analyze = functree._functions[fname] + if to_analyze then + self:report("Found", fname, "in functions to analyse") + functree:analyse_function(fname, to_analyze) + contract = functree._contracts[fname] + if contract then + if contract.pre + and not contract.pre(match.call, self._root, self._stack, self._bufnr) then + local call_row = match.fname:range() + self:report(string.format("Precondition violation on row: %d (%s)", call_row, fname)) + return {}, true + end + + if contract.post then + contract.post(self.call, self._root, self._stack, self._bufnr) + end + else + self:report(fname, "did not generate any contract") + end + end + end + return { self }, false + elseif match.ret then + -- TODO: handle lua_error calls in ret + local ret_row = match.ret:range() + if self._root.after and not self._root.after:match(self._stack, true) then + self:report(string.format("Post condition not matching on row %d", ret_row)) + return {}, true + end + + return {}, false + elseif match.ifs then + -- Split the state in two + local right = self:split() + if match.right then + table.insert(self._ignored, match.right) + end + table.insert(right._ignored, match.left) + + return { self, right } + elseif match.comment then + local ctext = tsq.get_node_text(match.comment, self._bufnr) + local assert_stack = Stack.new(ctext) + local assert_row = match.comment:range() + self:report("Verifying assertion row", assert_row) + if self._stack:match(assert_stack, match.strict) then + return { self }, false + else + self:report(string.format("Assertion failed at row %d: \"%s\"", assert_row, ctext)) + return {}, true + end + end +end + +function State:print() + print("State", self._id) + self._stack:print() +end + +function FunctionTree:analyse_function(name, func) + print("Analysing", name) + func.done = true + local ANALYZE_QUERY = [[ + ;; Root call + (call_expression function: (identifier) @fname) @call @node + + (if_statement consequence: (_) @left alternative: (_)? @right) @ifs @node + + ((comment) @comment @node (#lua-match? @comment "^//sa ")) + ((comment) @comment @node @strict (#lua-match? @comment "^//sA ")) + + (return_statement) @ret @node]] + local analyze_query = tsq.parse_query('c', ANALYZE_QUERY) + local steps = {} + for _,match in analyze_query:iter_matches(func.body, self._bufnr) do + match = lookup(analyze_query, match) + table.insert(steps, match) + end + + -- Sort the steps in the order of appearance in the file + table.sort(steps, function(a,b) + local a_srow, a_scol = a.node:range() + local b_srow, b_scol = b.node:range() + + return a_srow < b_srow or (a_srow == b_srow and a_scol < b_scol) + end) + + local states = { State.new(name, func, self._bufnr, steps) } + local iterations = 0 + local errored = false + while #states > 0 and iterations < 10000 do + local s = table.remove(states) + local new_states, has_error = s:step(self) + errored = errored or has_error + for _,new in ipairs(new_states) do + table.insert(states, new) + end + iterations = iterations + 1 + end + + if not errored then + self._contracts[name] = { + pre = function(_call, _caller, stack, _bufnr) + if func.up then + return stack:match(func.up) + end + if func.before then + return stack:match(func.before) + end + end, + post = function(_call, _caller, stack, _bufnr) + if func.before then + for i=1,func.before:len() do + stack:pop() + end + end + if func.after then + stack:push_all(func.after) + end + end + } + end +end + +function FunctionTree:analyse(funcnames) + if funcnames then + for _,fname in ipairs(funcnames) do + local to_analyze = self._functions[fname] + if to_analyze and not to_analyze.done then + self:analyse_function(fname, to_analyze) + end + end + else + for name,func in pairs(self._functions) do + if not func.done then + self:analyse_function(name, func) + end + end + end +end + +function FunctionTree.new(buffer) + buffer = buffer or api.nvim_get_current_buf() + + local self = setmetatable({ + _bufnr = buffer, + _functions = {}, + _contracts = vim.deepcopy(default_contracts) + }, FunctionTree) + + local parser = ts.get_parser(buffer, 'c', {}) + local root = parser:parse()[1]:root() + + for _,match in query:iter_matches(root, buffer) do + -- Massage the match a little + + local new_match = lookup(query, match) + local fname = tsq.get_node_text(new_match.fname, buffer) + if not self._functions[fname] then + self._functions[fname] = { body = new_match.body, done = false } + end + for _,id in pairs { "up", "before", "after" } do + local node = new_match[id .. "_comment"] + if node then + local text = tsq.get_node_text(node, buffer) + self._functions[fname][id] = Stack.new(text) + end + end + end + return self +end + + +M.FunctionTree = FunctionTree + +return M