commit cd070614ce2c018cdae8bdd9a56b1824fc1a7cb9
parent 3433819c69837aafd9661414e57fc088cf4766dd
Author: Thomas Vigouroux <thomas.vigouroux@protonmail.com>
Date: Sun, 3 Apr 2022 15:11:31 +0200
feat: add static lua stack analyzer
Diffstat:
1 file changed, 560 insertions(+), 0 deletions(-)
diff --git a/lua/lua-stack-check.lua b/lua/lua-stack-check.lua
@@ -0,0 +1,560 @@
+local M = {}
+
+local api = vim.api
+local ts = vim.treesitter
+local tsq = require'vim.treesitter.query'
+
+local QUERY = [[
+;; Upvalues are defined before the function definition
+((comment) @up_comment
+ .
+ (_ (function_declarator declarator: (identifier) @fname)
+ body: (_)) @body
+ (#lua-match? @up_comment "^//s "))
+
+;; Stack arguments are defined in the first comment of the function body
+(function_definition
+ declarator: (_ declarator: (identifier) @fname)
+ body: (compound_statement . (comment) @before_comment)
+ (#lua-match? @before_comment "^//s ")) @body
+
+;; Return stack state is defined in the last comment of the function body
+(function_definition
+ declarator: (_ declarator: (identifier) @fname)
+ body: (compound_statement (comment) @after_comment . )
+ (#lua-match? @after_comment "^//s ")) @body
+]]
+local query = tsq.parse_query('c', QUERY)
+
+local function parse_stack(comment)
+ comment = comment:sub(5)
+ -- First collect all items
+ local items = {}
+ local init = 1
+ while true do
+ local start, stop = string.find(comment, "[a-z_]+", init)
+ if start then
+ table.insert(items, comment:sub(start, stop))
+ init = stop + 1
+ else
+ break
+ end
+ end
+ return items
+end
+
+------------------
+-- Type
+------------------
+
+local Type = {
+ UNKNOWN = 0,
+ STRING = 1,
+ INTEGER = 2,
+ BOOLEAN = 3,
+ TABLE = 4,
+ NIL = 5,
+ USERDATA = 6,
+}
+vim.tbl_add_reverse_lookup(Type)
+
+M.Type = Type
+
+------------------
+-- Element
+------------------
+
+local Element = {}
+Element.__index = Element
+
+function Element.new(name, type)
+ local self = setmetatable({
+ _name = name or "",
+ _type = type or Type.UNKNOWN
+ }, Element)
+ return self
+end
+
+function Element.nil_()
+ return Element.new("nil", Type.NIL)
+end
+
+function Element.anon(type)
+ return Element.new("", type)
+end
+
+M.Element = Element
+
+------------------
+-- Stack
+------------------
+
+local Stack = {}
+Stack.__index = Stack
+
+function Stack.new(items)
+ local inner = {}
+ if items then
+ for _, elem in pairs(parse_stack(items)) do
+ local type = Type.UNKNOWN
+ if elem == "nil" then
+ type = Type.NIL
+ end
+ table.insert(inner, Element.new(elem, type))
+ end
+ end
+ local self = setmetatable({
+ _inner = inner,
+ }, Stack)
+ return self
+end
+
+function Stack:copy()
+ return setmetatable(vim.deepcopy(self), Stack)
+end
+
+-- Stack manipulation
+
+function Stack:insert(...)
+ table.insert(self._inner, ...)
+end
+
+function Stack:push(element)
+ self:insert(element)
+end
+
+function Stack:push_all(other)
+ for _,elem in pairs(other._inner) do
+ self:push(elem)
+ end
+end
+
+function Stack:remove(...)
+ return table.remove(self._inner,...)
+end
+
+function Stack:pop()
+ return self:remove()
+end
+
+function Stack:get(index)
+ if index < 0 then
+ return self._inner[#self._inner + index + 1]
+ else
+ return self._inner[index]
+ end
+end
+
+function Stack:print()
+ local to_print = {}
+ for index,element in ipairs(self._inner) do
+ to_print[index] = string.format("%s: %s", element._name, Type[element._type])
+ end
+ print(vim.inspect(to_print))
+end
+
+function Stack:match(other, strict)
+ local short, long
+
+ local function report()
+ print("Stacks not matching...")
+ self:print()
+ other:print()
+ end
+
+ if strict then
+ short = self._inner
+ long = other._inner
+ if #short ~= #long then
+ report()
+ return false
+ end
+ else
+ if #self._inner > #other._inner then
+ short = other._inner
+ long = self._inner
+ else
+ short = self._inner
+ long = other._inner
+ end
+ end
+
+ for i=1,#short do
+ local left = short[i]
+ local right = long[#long - #short + i]
+
+ local function sync_names()
+ local left_has_name = #(left._name)
+ local right_has_name = #(right._name)
+ if left_has_name and not right_has_name then
+ right._name = left._name
+ elseif right_has_name and not left_has_name then
+ left._name = right._name
+ end
+ end
+
+ local types_unknown = left._type == Type.UNKNOWN or right._type == Type.UNKNOWN
+ local types_are_nil = left._type == Type.NIL or right._type == Type.NIL
+ if not types_unknown
+ and left._type ~= right._type
+ and not types_are_nil then
+ report()
+ return false
+ elseif types_unknown then
+ if left._type == Type.UNKNOWN then
+ left._type = right._type
+ else
+ right._type = left._type
+ end
+ end
+ sync_names()
+ end
+ return true
+end
+
+function Stack:len()
+ return #self._inner
+end
+
+M.Stack = Stack
+
+------------------
+-- FunctionTree
+------------------
+
+local FunctionTree = { }
+FunctionTree.__index = FunctionTree
+
+local function lookup(q, match)
+ local lookup_match = {}
+ for id, node in pairs(match) do
+ lookup_match[q.captures[id]] = node
+ end
+ return lookup_match
+end
+
+-- A function that returns a function that returns a function
+-- Used to cache the arg_query in one place
+local make_check_function = (function()
+ local arg_query = tsq.parse_query('c', [[
+ (argument_list
+ (call_expression
+ function: (_) @fname
+ arguments: (argument_list (_) @num .) @call .))
+ (argument_list (number_literal) @num .)
+ ]])
+
+ return function(type)
+ return function(call, caller, stack, bufnr)
+ local _,last_arg = arg_query:iter_matches(call, bufnr)()
+ last_arg = lookup(arg_query, last_arg)
+ if last_arg.call then
+ local fname = tsq.get_node_text(last_arg.fname, bufnr)
+ if fname == "lua_upvalueindex" then
+ if not caller.up then return false end
+ local index = tonumber(tsq.get_node_text(last_arg.num, bufnr))
+ local element = caller.up:get(index)
+ if not element then return false end
+
+ if element._type == Type.UNKNOWN then
+ element._type = type
+ return true
+ end
+
+ return element._type == type
+ else
+ -- Can't analyze this call: assume it is correct
+ return true
+ end
+ else
+ local index = tonumber(tsq.get_node_text(last_arg.num, bufnr))
+ local element = stack:get(index)
+ if not element then return false end
+
+ return element._type == type
+ end
+ end
+ end
+end)()
+
+local default_contracts = {
+ lua_pushboolean = {
+ post = function(_call, _caller, stack, _bufnr)
+ stack:push(Element.anon(Type.BOOLEAN))
+ end
+ },
+ lua_pushnil = {
+ post = function(_call, _caller, stack, _bufnr)
+ stack:push(Element.nil_())
+ end
+ },
+ lua_newuserdata = {
+ post = function(_call, _caller, stack, _bufnr)
+ stack:push(Element.anon(Type.USERDATA))
+ end
+ },
+ luaL_checkstring = {
+ pre = make_check_function(Type.STRING)
+ }
+}
+
+local State = {}
+State.__index = State
+local state_count = 0
+
+function State.new(name, func, bufnr, steps)
+ local stack
+ if func.before then
+ stack = func.before:copy()
+ else
+ stack = Stack.new()
+ end
+
+ state_count = state_count + 1;
+ return setmetatable({
+ _id = state_count,
+ _root = func,
+ _root_name = name,
+ _stack = stack,
+ _bufnr = bufnr,
+ _steps = steps,
+ _ignored = {},
+ _node_index = 1
+ }, State)
+end
+
+function State:split()
+ local ignored = {}
+ for node,_ in pairs(self._ignored) do
+ ignored[node] = true
+ end
+
+ state_count = state_count + 1;
+ return setmetatable({
+ _id = state_count,
+ _root = self._root,
+ _root_name = self._root_name,
+ _stack = self._stack:copy(),
+ _bufnr = self._bufnr,
+ _steps = self._steps,
+ _ignored = ignored,
+ _node_index = self._node_index
+ }, State)
+end
+
+function State:_is_ignored(node)
+ for _,dest in ipairs(self._ignored) do
+ local n_sr, n_sc, n_er, n_ec = node:range()
+ local d_sr, d_sc, d_er, d_ec = dest:range()
+
+ local start_fits = n_sr > d_sr or (d_sr == n_sr and n_sc > d_sc)
+ local end_fits = n_er < d_er or (n_er == d_er and n_ec < d_ec)
+ if start_fits and end_fits then return true end
+ end
+ return false
+end
+
+function State:report(...)
+ print(" ", self._id, ":", ...)
+end
+
+function State:step(functree)
+ local match = self._steps[self._node_index]
+ self._node_index = self._node_index + 1
+
+ if not match then return {}, false end
+ if self:_is_ignored(match.node) then return { self }, false end
+
+ if match.call then
+ -- This is a function call
+ local fname = tsq.get_node_text(match.fname, self._bufnr)
+ local contract = functree._contracts[fname]
+ if contract then
+ if contract.pre
+ and not contract.pre(match.call, self._root, self._stack, self._bufnr) then
+ local call_row = match.fname:range()
+ self:report(string.format("Precondition violation on row: %d (%s)", call_row, fname))
+ return {}, true
+ end
+
+ if contract.post then
+ contract.post(self.call, self._root, self._stack, self._bufnr)
+ end
+ else
+ self:report("Missing contract for", fname)
+ local to_analyze = functree._functions[fname]
+ if to_analyze then
+ self:report("Found", fname, "in functions to analyse")
+ functree:analyse_function(fname, to_analyze)
+ contract = functree._contracts[fname]
+ if contract then
+ if contract.pre
+ and not contract.pre(match.call, self._root, self._stack, self._bufnr) then
+ local call_row = match.fname:range()
+ self:report(string.format("Precondition violation on row: %d (%s)", call_row, fname))
+ return {}, true
+ end
+
+ if contract.post then
+ contract.post(self.call, self._root, self._stack, self._bufnr)
+ end
+ else
+ self:report(fname, "did not generate any contract")
+ end
+ end
+ end
+ return { self }, false
+ elseif match.ret then
+ -- TODO: handle lua_error calls in ret
+ local ret_row = match.ret:range()
+ if self._root.after and not self._root.after:match(self._stack, true) then
+ self:report(string.format("Post condition not matching on row %d", ret_row))
+ return {}, true
+ end
+
+ return {}, false
+ elseif match.ifs then
+ -- Split the state in two
+ local right = self:split()
+ if match.right then
+ table.insert(self._ignored, match.right)
+ end
+ table.insert(right._ignored, match.left)
+
+ return { self, right }
+ elseif match.comment then
+ local ctext = tsq.get_node_text(match.comment, self._bufnr)
+ local assert_stack = Stack.new(ctext)
+ local assert_row = match.comment:range()
+ self:report("Verifying assertion row", assert_row)
+ if self._stack:match(assert_stack, match.strict) then
+ return { self }, false
+ else
+ self:report(string.format("Assertion failed at row %d: \"%s\"", assert_row, ctext))
+ return {}, true
+ end
+ end
+end
+
+function State:print()
+ print("State", self._id)
+ self._stack:print()
+end
+
+function FunctionTree:analyse_function(name, func)
+ print("Analysing", name)
+ func.done = true
+ local ANALYZE_QUERY = [[
+ ;; Root call
+ (call_expression function: (identifier) @fname) @call @node
+
+ (if_statement consequence: (_) @left alternative: (_)? @right) @ifs @node
+
+ ((comment) @comment @node (#lua-match? @comment "^//sa "))
+ ((comment) @comment @node @strict (#lua-match? @comment "^//sA "))
+
+ (return_statement) @ret @node]]
+ local analyze_query = tsq.parse_query('c', ANALYZE_QUERY)
+ local steps = {}
+ for _,match in analyze_query:iter_matches(func.body, self._bufnr) do
+ match = lookup(analyze_query, match)
+ table.insert(steps, match)
+ end
+
+ -- Sort the steps in the order of appearance in the file
+ table.sort(steps, function(a,b)
+ local a_srow, a_scol = a.node:range()
+ local b_srow, b_scol = b.node:range()
+
+ return a_srow < b_srow or (a_srow == b_srow and a_scol < b_scol)
+ end)
+
+ local states = { State.new(name, func, self._bufnr, steps) }
+ local iterations = 0
+ local errored = false
+ while #states > 0 and iterations < 10000 do
+ local s = table.remove(states)
+ local new_states, has_error = s:step(self)
+ errored = errored or has_error
+ for _,new in ipairs(new_states) do
+ table.insert(states, new)
+ end
+ iterations = iterations + 1
+ end
+
+ if not errored then
+ self._contracts[name] = {
+ pre = function(_call, _caller, stack, _bufnr)
+ if func.up then
+ return stack:match(func.up)
+ end
+ if func.before then
+ return stack:match(func.before)
+ end
+ end,
+ post = function(_call, _caller, stack, _bufnr)
+ if func.before then
+ for i=1,func.before:len() do
+ stack:pop()
+ end
+ end
+ if func.after then
+ stack:push_all(func.after)
+ end
+ end
+ }
+ end
+end
+
+function FunctionTree:analyse(funcnames)
+ if funcnames then
+ for _,fname in ipairs(funcnames) do
+ local to_analyze = self._functions[fname]
+ if to_analyze and not to_analyze.done then
+ self:analyse_function(fname, to_analyze)
+ end
+ end
+ else
+ for name,func in pairs(self._functions) do
+ if not func.done then
+ self:analyse_function(name, func)
+ end
+ end
+ end
+end
+
+function FunctionTree.new(buffer)
+ buffer = buffer or api.nvim_get_current_buf()
+
+ local self = setmetatable({
+ _bufnr = buffer,
+ _functions = {},
+ _contracts = vim.deepcopy(default_contracts)
+ }, FunctionTree)
+
+ local parser = ts.get_parser(buffer, 'c', {})
+ local root = parser:parse()[1]:root()
+
+ for _,match in query:iter_matches(root, buffer) do
+ -- Massage the match a little
+
+ local new_match = lookup(query, match)
+ local fname = tsq.get_node_text(new_match.fname, buffer)
+ if not self._functions[fname] then
+ self._functions[fname] = { body = new_match.body, done = false }
+ end
+ for _,id in pairs { "up", "before", "after" } do
+ local node = new_match[id .. "_comment"]
+ if node then
+ local text = tsq.get_node_text(node, buffer)
+ self._functions[fname][id] = Stack.new(text)
+ end
+ end
+ end
+ return self
+end
+
+
+M.FunctionTree = FunctionTree
+
+return M