nvim-config

Log | Files | Refs | Submodules | README

lua-stack-check.lua (14095B)


      1 local M = {}
      2 
      3 local api = vim.api
      4 local ts = vim.treesitter
      5 local tsq = require'vim.treesitter.query'
      6 
      7 local QUERY = [[
      8 ;; Upvalues are defined before the function definition
      9 ((comment) @up_comment
     10  .
     11  (_ (function_declarator declarator: (identifier) @fname)
     12   body: (_)) @body
     13  (#lua-match? @up_comment "^//s "))
     14 
     15 ;; Stack arguments are defined in the first comment of the function body
     16 (function_definition
     17  declarator: (_ declarator: (identifier) @fname)
     18  body: (compound_statement . (comment) @before_comment)
     19  (#lua-match? @before_comment "^//s ")) @body
     20 
     21 ;; Return stack state is defined in the last comment of the function body
     22 (function_definition
     23  declarator: (_ declarator: (identifier) @fname)
     24  body: (compound_statement (comment) @after_comment . )
     25  (#lua-match? @after_comment "^//s ")) @body
     26 ]]
     27 local query = tsq.parse_query('c', QUERY)
     28 
     29 local function parse_stack(comment)
     30   comment = comment:sub(5)
     31   -- First collect all items
     32   local items = {}
     33   local init = 1
     34   while true do
     35     local start, stop = string.find(comment, "[a-z_]+", init)
     36     if start then
     37       table.insert(items, comment:sub(start, stop))
     38       init = stop + 1
     39     else
     40       break
     41     end
     42   end
     43   return items
     44 end
     45 
     46 ------------------
     47 -- Type
     48 ------------------
     49 
     50 local Type = {
     51   UNKNOWN = 0,
     52   STRING = 1,
     53   INTEGER = 2,
     54   BOOLEAN = 3,
     55   TABLE = 4,
     56   NIL = 5,
     57   USERDATA = 6,
     58 }
     59 vim.tbl_add_reverse_lookup(Type)
     60 
     61 M.Type = Type
     62 
     63 ------------------
     64 -- Element
     65 ------------------
     66 
     67 local Element = {}
     68 Element.__index = Element
     69 
     70 function Element.new(name, type)
     71   local self = setmetatable({
     72     _name = name or "",
     73     _type = type or Type.UNKNOWN
     74   }, Element)
     75   return self
     76 end
     77 
     78 function Element.nil_()
     79   return Element.new("nil", Type.NIL)
     80 end
     81 
     82 function Element.anon(type)
     83   return Element.new("", type)
     84 end
     85 
     86 M.Element = Element
     87 
     88 ------------------
     89 -- Stack
     90 ------------------
     91 
     92 local Stack = {}
     93 Stack.__index = Stack
     94 
     95 function Stack.new(items)
     96   local inner = {}
     97   if items then
     98     for _, elem in pairs(parse_stack(items)) do
     99       local type = Type.UNKNOWN
    100       if elem == "nil" then
    101         type = Type.NIL
    102       end
    103       table.insert(inner, Element.new(elem, type))
    104     end
    105   end
    106   local self = setmetatable({
    107     _inner = inner,
    108   }, Stack)
    109   return self
    110 end
    111 
    112 function Stack:copy()
    113   return setmetatable(vim.deepcopy(self), Stack)
    114 end
    115 
    116 -- Stack manipulation
    117 
    118 function Stack:insert(...)
    119   table.insert(self._inner, ...)
    120 end
    121 
    122 function Stack:push(element)
    123   self:insert(element)
    124 end
    125 
    126 function Stack:push_all(other)
    127   for _,elem in pairs(other._inner) do
    128     self:push(elem)
    129   end
    130 end
    131 
    132 function Stack:remove(...)
    133   return table.remove(self._inner,...)
    134 end
    135 
    136 function Stack:pop()
    137   return self:remove()
    138 end
    139 
    140 function Stack:get(index)
    141   if index < 0 then
    142     return self._inner[#self._inner + index + 1]
    143   else
    144     return self._inner[index]
    145   end
    146 end
    147 
    148 function Stack:print()
    149   local to_print = {}
    150   for index,element in ipairs(self._inner) do
    151     to_print[index] = string.format("%s: %s", element._name, Type[element._type])
    152   end
    153   print(vim.inspect(to_print))
    154 end
    155 
    156 function Stack:match(other, strict)
    157   local short, long
    158 
    159   local function report()
    160     print("Stacks not matching...")
    161     self:print()
    162     other:print()
    163   end
    164 
    165   if strict then
    166     short = self._inner
    167     long = other._inner
    168     if #short ~= #long then
    169       report()
    170       return false
    171     end
    172   else
    173     if #self._inner > #other._inner then
    174       short = other._inner
    175       long = self._inner
    176     else
    177       short = self._inner
    178       long = other._inner
    179     end
    180   end
    181 
    182   for i=1,#short do
    183     local left = short[i]
    184     local right = long[#long - #short + i]
    185 
    186     local function sync_names()
    187       local left_has_name = #(left._name)
    188       local right_has_name = #(right._name)
    189       if left_has_name and not right_has_name then
    190         right._name = left._name
    191       elseif right_has_name and not left_has_name then
    192         left._name = right._name
    193       end
    194     end
    195 
    196     local types_unknown = left._type == Type.UNKNOWN or right._type == Type.UNKNOWN
    197     local types_are_nil = left._type == Type.NIL or right._type == Type.NIL
    198     if not types_unknown
    199         and left._type ~= right._type
    200         and not types_are_nil then
    201       report()
    202       return false
    203     elseif types_unknown then
    204       if left._type == Type.UNKNOWN then
    205         left._type = right._type
    206       else
    207         right._type = left._type
    208       end
    209     end
    210     sync_names()
    211   end
    212   return true
    213 end
    214 
    215 function Stack:len()
    216   return #self._inner
    217 end
    218 
    219 M.Stack = Stack
    220 
    221 ------------------
    222 -- FunctionTree
    223 ------------------
    224 
    225 local FunctionTree = { }
    226 FunctionTree.__index = FunctionTree
    227 
    228 local function lookup(q, match)
    229   local lookup_match = {}
    230   for id, node in pairs(match) do
    231     lookup_match[q.captures[id]] = node
    232   end
    233   return lookup_match
    234 end
    235 
    236 -- A function that returns a function that returns a function
    237 -- Used to cache the arg_query in one place
    238 local make_check_function = (function()
    239       local arg_query = tsq.parse_query('c', [[
    240         (argument_list
    241           (call_expression
    242             function: (_) @fname
    243             arguments: (argument_list (_) @num .) @call .))
    244         (argument_list (number_literal) @num .)
    245       ]])
    246 
    247       return function(type)
    248         return function(call, caller, stack, bufnr)
    249           local _,last_arg = arg_query:iter_matches(call, bufnr)()
    250           last_arg = lookup(arg_query, last_arg)
    251           if last_arg.call then
    252             local fname = tsq.get_node_text(last_arg.fname, bufnr)
    253             if fname == "lua_upvalueindex" then
    254               if not caller.up then return false end
    255               local index = tonumber(tsq.get_node_text(last_arg.num, bufnr))
    256               local element = caller.up:get(index)
    257               if not element then return false end
    258 
    259               if element._type == Type.UNKNOWN then
    260                 element._type = type
    261                 return true
    262               end
    263 
    264               return element._type == type
    265             else
    266               -- Can't analyze this call: assume it is correct
    267               return true
    268             end
    269           else
    270             local index = tonumber(tsq.get_node_text(last_arg.num, bufnr))
    271             local element = stack:get(index)
    272             if not element then return false end
    273 
    274             return element._type == type
    275           end
    276         end
    277       end
    278 end)()
    279 
    280 local default_contracts = {
    281   lua_pushboolean = {
    282     post = function(_call, _caller, stack, _bufnr)
    283       stack:push(Element.anon(Type.BOOLEAN))
    284     end
    285   },
    286   lua_pushnil = {
    287     post = function(_call, _caller, stack, _bufnr)
    288       stack:push(Element.nil_())
    289     end
    290   },
    291   lua_newuserdata = {
    292     post = function(_call, _caller, stack, _bufnr)
    293       stack:push(Element.anon(Type.USERDATA))
    294     end
    295   },
    296   luaL_checkstring = {
    297     pre = make_check_function(Type.STRING)
    298   }
    299 }
    300 
    301 local State = {}
    302 State.__index = State
    303 local state_count = 0
    304 
    305 function State.new(name, func, bufnr, steps)
    306   local stack
    307   if func.before then
    308     stack = func.before:copy()
    309   else
    310     stack = Stack.new()
    311   end
    312 
    313   state_count = state_count + 1;
    314   return setmetatable({
    315     _id = state_count,
    316     _root = func,
    317     _root_name = name,
    318     _stack = stack,
    319     _bufnr = bufnr,
    320     _steps = steps,
    321     _ignored = {},
    322     _node_index = 1
    323   }, State)
    324 end
    325 
    326 function State:split()
    327   local ignored = {}
    328   for node,_ in pairs(self._ignored) do
    329     ignored[node] = true
    330   end
    331 
    332   state_count = state_count + 1;
    333   return setmetatable({
    334     _id = state_count,
    335     _root = self._root,
    336     _root_name = self._root_name,
    337     _stack = self._stack:copy(),
    338     _bufnr = self._bufnr,
    339     _steps = self._steps,
    340     _ignored = ignored,
    341     _node_index = self._node_index
    342   }, State)
    343 end
    344 
    345 function State:_is_ignored(node)
    346   for _,dest in ipairs(self._ignored) do
    347     local n_sr, n_sc, n_er, n_ec = node:range()
    348     local d_sr, d_sc, d_er, d_ec = dest:range()
    349 
    350     local start_fits = n_sr > d_sr or (d_sr == n_sr and n_sc > d_sc)
    351     local end_fits = n_er < d_er or (n_er == d_er and n_ec < d_ec)
    352     if start_fits and end_fits then return true end
    353   end
    354   return false
    355 end
    356 
    357 function State:report(...)
    358   print("  ", self._id, ":", ...)
    359 end
    360 
    361 function State:step(functree)
    362   local match = self._steps[self._node_index]
    363   self._node_index = self._node_index + 1
    364 
    365   if not match then return {}, false end
    366   if self:_is_ignored(match.node) then return { self }, false end
    367 
    368   if match.call then
    369     -- This is a function call
    370     local fname = tsq.get_node_text(match.fname, self._bufnr)
    371     local contract = functree._contracts[fname]
    372     if contract then
    373       if contract.pre
    374         and not contract.pre(match.call, self._root, self._stack, self._bufnr) then
    375         local call_row = match.fname:range()
    376         self:report(string.format("Precondition violation on row: %d (%s)", call_row, fname))
    377         return {}, true
    378       end
    379 
    380       if contract.post then
    381         contract.post(self.call, self._root, self._stack, self._bufnr)
    382       end
    383     else
    384       self:report("Missing contract for", fname)
    385       local to_analyze = functree._functions[fname]
    386       if to_analyze then
    387         self:report("Found", fname, "in functions to analyse")
    388         functree:analyse_function(fname, to_analyze)
    389         contract = functree._contracts[fname]
    390         if contract then
    391           if contract.pre
    392             and not contract.pre(match.call, self._root, self._stack, self._bufnr) then
    393             local call_row = match.fname:range()
    394             self:report(string.format("Precondition violation on row: %d (%s)", call_row, fname))
    395             return {}, true
    396           end
    397 
    398           if contract.post then
    399             contract.post(self.call, self._root, self._stack, self._bufnr)
    400           end
    401         else
    402           self:report(fname, "did not generate any contract")
    403         end
    404       end
    405     end
    406     return { self }, false
    407   elseif match.ret then
    408     -- TODO: handle lua_error calls in ret
    409     local ret_row = match.ret:range()
    410     if self._root.after and not self._root.after:match(self._stack, true) then
    411       self:report(string.format("Post condition not matching on row %d", ret_row))
    412       return {}, true
    413     end
    414 
    415     return {}, false
    416   elseif match.ifs then
    417     -- Split the state in two
    418     local right = self:split()
    419     if match.right then
    420       table.insert(self._ignored, match.right)
    421     end
    422     table.insert(right._ignored, match.left)
    423 
    424     return { self, right }
    425   elseif match.comment then
    426     local ctext = tsq.get_node_text(match.comment, self._bufnr)
    427     local assert_stack = Stack.new(ctext)
    428     local assert_row = match.comment:range()
    429     self:report("Verifying assertion row", assert_row)
    430     if self._stack:match(assert_stack, match.strict) then
    431       return { self }, false
    432     else
    433       self:report(string.format("Assertion failed at row %d: \"%s\"", assert_row, ctext))
    434       return {}, true
    435     end
    436   end
    437 end
    438 
    439 function State:print()
    440   print("State", self._id)
    441   self._stack:print()
    442 end
    443 
    444 function FunctionTree:analyse_function(name, func)
    445   print("Analysing", name)
    446   func.done = true
    447   local ANALYZE_QUERY = [[
    448   ;; Root call
    449   (call_expression function: (identifier) @fname) @call @node
    450 
    451   (if_statement consequence: (_) @left alternative: (_)? @right) @ifs @node
    452 
    453   ((comment) @comment @node (#lua-match? @comment "^//sa "))
    454   ((comment) @comment @node @strict (#lua-match? @comment "^//sA "))
    455 
    456   (return_statement) @ret @node]]
    457   local analyze_query = tsq.parse_query('c', ANALYZE_QUERY)
    458   local steps = {}
    459   for _,match in analyze_query:iter_matches(func.body, self._bufnr) do
    460     match = lookup(analyze_query, match)
    461     table.insert(steps, match)
    462   end
    463 
    464   -- Sort the steps in the order of appearance in the file
    465   table.sort(steps, function(a,b)
    466     local a_srow, a_scol = a.node:range()
    467     local b_srow, b_scol = b.node:range()
    468 
    469     return a_srow < b_srow or (a_srow == b_srow and a_scol < b_scol)
    470   end)
    471 
    472   local states = { State.new(name, func, self._bufnr, steps) }
    473   local iterations = 0
    474   local errored = false
    475   while #states > 0 and iterations < 10000 do
    476     local s = table.remove(states)
    477     local new_states, has_error = s:step(self)
    478     errored = errored or has_error
    479     for _,new in ipairs(new_states) do
    480       table.insert(states, new)
    481     end
    482     iterations = iterations + 1
    483   end
    484 
    485   if not errored then
    486     self._contracts[name] = {
    487       pre = function(_call, _caller, stack, _bufnr)
    488         if func.up then
    489           return stack:match(func.up)
    490         end
    491         if func.before then
    492           return stack:match(func.before)
    493         end
    494       end,
    495       post = function(_call, _caller, stack, _bufnr)
    496         if func.before then
    497           for i=1,func.before:len() do
    498             stack:pop()
    499           end
    500         end
    501         if func.after then
    502           stack:push_all(func.after)
    503         end
    504       end
    505     }
    506   end
    507 end
    508 
    509 function FunctionTree:analyse(funcnames)
    510   if funcnames then
    511     for _,fname in ipairs(funcnames) do
    512       local to_analyze = self._functions[fname]
    513       if to_analyze and not to_analyze.done then
    514         self:analyse_function(fname, to_analyze)
    515       end
    516     end
    517   else
    518     for name,func in pairs(self._functions) do
    519       if not func.done then
    520         self:analyse_function(name, func)
    521       end
    522     end
    523   end
    524 end
    525 
    526 function FunctionTree.new(buffer)
    527   buffer = buffer or api.nvim_get_current_buf()
    528 
    529   local self = setmetatable({
    530     _bufnr = buffer,
    531     _functions = {},
    532     _contracts = vim.deepcopy(default_contracts)
    533   }, FunctionTree)
    534 
    535   local parser = ts.get_parser(buffer, 'c', {})
    536   local root = parser:parse()[1]:root()
    537 
    538   for _,match in query:iter_matches(root, buffer) do
    539     -- Massage the match a little
    540 
    541     local new_match = lookup(query, match)
    542     local fname = tsq.get_node_text(new_match.fname, buffer)
    543     if not self._functions[fname] then
    544       self._functions[fname] = { body = new_match.body, done = false }
    545     end
    546     for _,id in pairs { "up", "before", "after" } do
    547       local node = new_match[id .. "_comment"]
    548       if node then
    549         local text = tsq.get_node_text(node, buffer)
    550         self._functions[fname][id] = Stack.new(text)
    551       end
    552     end
    553   end
    554   return self
    555 end
    556 
    557 
    558 M.FunctionTree = FunctionTree
    559 
    560 return M