lua-stack-check.lua (14095B)
--- Static checker for Lua C API stack usage in C source buffers.
--
-- C functions are annotated with `//s` comments naming the stack slots:
-- a `//s` comment right before a function lists its upvalues, the first
-- comment of its body the stack on entry and the last comment of its body
-- the stack when it returns. Inside a body, `//sa` comments are non-strict
-- stack assertions (only the overlapping top slots are compared) and `//sA`
-- comments are strict ones (the stack heights must match as well).
local M = {}

local api = vim.api
local ts = vim.treesitter
local tsq = require'vim.treesitter.query'

-- Finds the `//s` annotations of every C function. @fname is the function
-- name, @body the function definition node, and the *_comment captures are
-- the annotation comments.
local QUERY = [[
;; Upvalues are defined before the function definition
((comment) @up_comment
 .
 (_ (function_declarator declarator: (identifier) @fname)
    body: (_)) @body
 (#lua-match? @up_comment "^//s "))

;; Stack arguments are defined in the first comment of the function body
(function_definition
  declarator: (_ declarator: (identifier) @fname)
  body: (compound_statement . (comment) @before_comment)
  (#lua-match? @before_comment "^//s ")) @body

;; Return stack state is defined in the last comment of the function body
(function_definition
  declarator: (_ declarator: (identifier) @fname)
  body: (compound_statement (comment) @after_comment . )
  (#lua-match? @after_comment "^//s ")) @body
]]
-- `vim.treesitter.query.parse` supersedes the deprecated `parse_query`;
-- use whichever one the running Neovim provides.
local query = (tsq.parse or tsq.parse_query)('c', QUERY)

--- Extract the element names from a stack annotation comment.
-- The first four characters (the `//s ` marker) are skipped. Every run of
-- lowercase letters and underscores in the remainder is one element, listed
-- bottom of the stack first.
-- @tparam string comment full comment text
-- @treturn {string,...} element names, in order
local function parse_stack(comment)
  local items = {}
  for item in comment:sub(5):gmatch("[a-z_]+") do
    items[#items + 1] = item
  end
  return items
end

------------------
-- Type
------------------

local Type = {
  UNKNOWN = 0,
  STRING = 1,
  INTEGER = 2,
  BOOLEAN = 3,
  TABLE = 4,
  NIL = 5,
  USERDATA = 6,
}
-- Add the reverse mapping (value -> name) used when printing stacks.
-- Done by hand because `vim.tbl_add_reverse_lookup` is deprecated. Names are
-- collected first because keys must not be added while traversing with pairs.
do
  local names = {}
  for name, value in pairs(Type) do
    names[value] = name
  end
  for value, name in pairs(names) do
    Type[value] = name
  end
end

M.Type = Type

------------------
-- Element
------------------

--- One slot of a simulated stack: a (possibly empty) name and a Type value.
local Element = {}
Element.__index = Element

--- Create an element.
-- @tparam[opt=""] string name slot name taken from an annotation
-- @tparam[opt=Type.UNKNOWN] number elem_type one of the Type values
-- @treturn Element
function Element.new(name, elem_type)
  return setmetatable({
    _name = name or "",
    _type = elem_type or Type.UNKNOWN,
  }, Element)
end

--- An element holding `nil`, named "nil".
function Element.nil_()
  return Element.new("nil", Type.NIL)
end

--- An unnamed element of the given type (Type.UNKNOWN when omitted).
function Element.anon(elem_type)
  return Element.new("", elem_type)
end

M.Element = Element

------------------
-- Stack
------------------

--- A simulated Lua stack: a sequence of Elements, index 1 is the bottom.
local Stack = {}
Stack.__index = Stack

--- Build a stack from a `//s ...` annotation comment.
-- Each name from the comment becomes one element. The literal name `nil`
-- gets Type.NIL and every other name gets Type.UNKNOWN.
-- @tparam[opt] string items annotation comment text; nil gives an empty stack
-- @treturn Stack
function Stack.new(items)
  local inner = {}
  if items then
    -- ipairs: parse_stack returns a sequence and slot order matters.
    for _, elem in ipairs(parse_stack(items)) do
      local elem_type = Type.UNKNOWN
      if elem == "nil" then
        elem_type = Type.NIL
      end
      inner[#inner + 1] = Element.new(elem, elem_type)
    end
  end
  return setmetatable({ _inner = inner }, Stack)
end

--- Deep copy of the stack, so the copy can be mutated independently.
function Stack:copy()
  return setmetatable(vim.deepcopy(self), Stack)
end

-- Stack manipulation

function Stack:insert(...)
  table.insert(self._inner, ...)
end

--- Push one element on top of the stack.
function Stack:push(element)
  self:insert(element)
end

--- Push copies of all elements of `other` on top of this stack, bottom first.
-- Copies are pushed because Stack:match mutates elements in place (type
-- inference, name syncing). Sharing them would leak those mutations back
-- into `other` (e.g. a callee's `after` contract) and would alias repeated
-- pushes of the same contract.
function Stack:push_all(other)
  for _, elem in ipairs(other._inner) do
    self:push(Element.new(elem._name, elem._type))
  end
end

function Stack:remove(...)
  return table.remove(self._inner, ...)
end

--- Remove and return the top element.
function Stack:pop()
  return self:remove()
end

--- Element at a Lua-C-API-style index: positive counts from the bottom
-- (1 = bottom), negative from the top (-1 = top). Returns nil when out of range.
function Stack:get(index)
  if index < 0 then
    return self._inner[#self._inner + index + 1]
  else
    return self._inner[index]
  end
end

function Stack:print()
  local to_print = {}
  for index, element in ipairs(self._inner) do
    to_print[index] = string.format("%s: %s", element._name, Type[element._type])
  end
  print(vim.inspect(to_print))
end

--- Compare this stack with `other`, inferring information in both directions.
-- Non-strict: the shorter stack is aligned with the top of the longer one and
-- only the overlapping slots are compared. Strict: the heights must be equal.
-- Two known, different types mismatch unless one of them is Type.NIL. An
-- UNKNOWN type takes the other side's type, and an empty name takes the
-- other side's name. Both stacks are mutated by this inference.
-- @tparam Stack other
-- @tparam[opt] boolean strict
-- @treturn boolean true when the stacks are compatible
function Stack:match(other, strict)
  local short, long

  local function report()
    print("Stacks not matching...")
    self:print()
    other:print()
  end

  if strict then
    short, long = self._inner, other._inner
    if #short ~= #long then
      report()
      return false
    end
  elseif #self._inner > #other._inner then
    short, long = other._inner, self._inner
  else
    short, long = self._inner, other._inner
  end

  local offset = #long - #short
  for i = 1, #short do
    local left = short[i]
    local right = long[offset + i]

    local left_unknown = left._type == Type.UNKNOWN
    local right_unknown = right._type == Type.UNKNOWN
    if left_unknown or right_unknown then
      -- Infer the unknown side from the other one.
      if left_unknown then
        left._type = right._type
      else
        right._type = left._type
      end
    elseif left._type ~= right._type
        and left._type ~= Type.NIL and right._type ~= Type.NIL then
      report()
      return false
    end

    -- Propagate a name to the side that has none. `#name` is a number and
    -- therefore always truthy in Lua, so compare against 0 explicitly.
    local left_named = #left._name > 0
    local right_named = #right._name > 0
    if left_named and not right_named then
      right._name = left._name
    elseif right_named and not left_named then
      left._name = right._name
    end
  end
  return true
end

function Stack:len()
  return #self._inner
end

M.Stack = Stack

------------------
-- FunctionTree
------------------

local FunctionTree = {}
FunctionTree.__index = FunctionTree

-- Neovim 0.9 introduced `vim.treesitter.query.parse` and
-- `vim.treesitter.get_node_text` and deprecated the older names. Prefer the
-- new ones and fall back on older releases.
local parse_query = tsq.parse or tsq.parse_query
local get_node_text = ts.get_node_text or tsq.get_node_text

--- Re-key a raw `iter_matches` match (capture id -> node) by capture name.
-- Newer Neovim releases yield a list of nodes per capture instead of a
-- single node. Such lists are reduced to their last node.
local function lookup(q, match)
  local lookup_match = {}
  for id, node in pairs(match) do
    if type(node) == "table" then
      node = node[#node]
    end
    lookup_match[q.captures[id]] = node
  end
  return lookup_match
end

-- Builds `pre` contract checks for luaL_check*-style functions. The outer
-- immediately-invoked function parses arg_query once. The returned factory
-- takes the expected Type, and the check it produces inspects the last
-- argument of the call.
local make_check_function = (function()
  local arg_query = parse_query('c', [[
(argument_list
  (call_expression
    function: (_) @fname
    arguments: (argument_list (_) @num .) @call .))
(argument_list (number_literal) @num .)
]])

  -- An UNKNOWN slot is inferred to be `expected_type` (the check guarantees
  -- it from here on). A known slot must already have that type.
  local function check_element(element, expected_type)
    if not element then return false end
    if element._type == Type.UNKNOWN then
      element._type = expected_type
      return true
    end
    return element._type == expected_type
  end

  return function(expected_type)
    return function(call, caller, stack, bufnr)
      local _, raw = arg_query:iter_matches(call, bufnr)()
      if not raw then
        -- The last argument is neither a number literal nor a call
        -- (e.g. a variable): can't analyze it, assume it is correct.
        return true
      end
      local last_arg = lookup(arg_query, raw)

      if last_arg.call then
        local fname = get_node_text(last_arg.fname, bufnr)
        if fname ~= "lua_upvalueindex" then
          -- Can't analyze this call: assume it is correct
          return true
        end
        if not caller.up then return false end
        local index = tonumber(get_node_text(last_arg.num, bufnr))
        if not index then return true end -- non-literal index: can't analyze
        return check_element(caller.up:get(index), expected_type)
      end

      local index = tonumber(get_node_text(last_arg.num, bufnr))
      if not index then return true end -- e.g. suffixed literal: can't analyze
      return check_element(stack:get(index), expected_type)
    end
  end
end)()

-- Built-in contracts for Lua C API functions. `pre` returns whether the call
-- is valid for the current stack; `post` applies the call's effect to it.
local default_contracts = {
  lua_pushboolean = {
    post = function(_call, _caller, stack, _bufnr)
      stack:push(Element.anon(Type.BOOLEAN))
    end,
  },
  lua_pushnil = {
    post = function(_call, _caller, stack, _bufnr)
      stack:push(Element.nil_())
    end,
  },
  lua_newuserdata = {
    post = function(_call, _caller, stack, _bufnr)
      stack:push(Element.anon(Type.USERDATA))
    end,
  },
  luaL_checkstring = {
    pre = make_check_function(Type.STRING),
  },
}

--- One execution path through a function body being simulated.
local State = {}
State.__index = State
local state_count = 0

--- Initial state for analysing `func`: its entry stack (a copy of the
-- `before` annotation, or empty) and the sorted analysis steps.
function State.new(name, func, bufnr, steps)
  local stack
  if func.before then
    stack = func.before:copy()
  else
    stack = Stack.new()
  end

  state_count = state_count + 1
  return setmetatable({
    _id = state_count,
    _root = func,
    _root_name = name,
    _stack = stack,
    _bufnr = bufnr,
    _steps = steps,
    _ignored = {},
    _node_index = 1,
  }, State)
end

--- Fork this state: same position, copied stack and copied ignore list.
function State:split()
  -- `_ignored` is a sequence of branch nodes (see State:_is_ignored), so it
  -- must be copied as a sequence, not as a set of its indices.
  local ignored = {}
  for i, node in ipairs(self._ignored) do
    ignored[i] = node
  end

  state_count = state_count + 1
  return setmetatable({
    _id = state_count,
    _root = self._root,
    _root_name = self._root_name,
    _stack = self._stack:copy(),
    _bufnr = self._bufnr,
    _steps = self._steps,
    _ignored = ignored,
    _node_index = self._node_index,
  }, State)
end

--- Whether `node` lies inside one of the branches this state does not take.
-- Containment is inclusive: in an un-braced branch such as
-- `if (x) return 1;` the step node starts exactly where the branch starts,
-- and it must still be skipped.
function State:_is_ignored(node)
  local n_sr, n_sc, n_er, n_ec = node:range()
  for _, dest in ipairs(self._ignored) do
    local d_sr, d_sc, d_er, d_ec = dest:range()

    local start_fits = n_sr > d_sr or (n_sr == d_sr and n_sc >= d_sc)
    local end_fits = n_er < d_er or (n_er == d_er and n_ec <= d_ec)
    if start_fits and end_fits then return true end
  end
  return false
end

function State:report(...)
  print(" ", self._id, ":", ...)
end

--- Run `contract` for the call in `match` against this state's stack.
-- The precondition is checked first and a violation is reported. Then the
-- postcondition updates the stack.
-- @treturn boolean false when the precondition was violated
function State:_apply_contract(contract, match, fname)
  if contract.pre
      and not contract.pre(match.call, self._root, self._stack, self._bufnr) then
    local call_row = match.fname:range()
    self:report(string.format("Precondition violation on row: %d (%s)", call_row, fname))
    return false
  end

  if contract.post then
    contract.post(match.call, self._root, self._stack, self._bufnr)
  end
  return true
end

--- Execute the next step of this path.
-- @tparam FunctionTree functree provides contracts and functions to analyse
-- @treturn {State,...} states to continue with (empty when the path ended)
-- @treturn boolean true when this step detected an error
function State:step(functree)
  local match = self._steps[self._node_index]
  self._node_index = self._node_index + 1

  if not match then return {}, false end
  if self:_is_ignored(match.node) then return { self }, false end

  if match.call then
    -- This is a function call
    local fname = get_node_text(match.fname, self._bufnr)
    local contract = functree._contracts[fname]
    if not contract then
      self:report("Missing contract for", fname)
      local to_analyze = functree._functions[fname]
      if to_analyze then
        self:report("Found", fname, "in functions to analyse")
        -- `done` is set when an analysis starts. Re-entering it for a
        -- (mutually) recursive call would recurse forever.
        if not to_analyze.done then
          functree:analyse_function(fname, to_analyze)
        end
        contract = functree._contracts[fname]
        if not contract then
          self:report(fname, "did not generate any contract")
        end
      end
    end
    if contract and not self:_apply_contract(contract, match, fname) then
      return {}, true
    end
    return { self }, false
  elseif match.ret then
    -- TODO: handle lua_error calls in ret
    local ret_row = match.ret:range()
    local after = self._root.after
    if after and not after:match(self._stack, true) then
      self:report(string.format("Post condition not matching on row %d", ret_row))
      return {}, true
    end
    return {}, false
  elseif match.ifs then
    -- Split the state in two: `self` takes the consequence, `right` the
    -- alternative (or skips the consequence when there is none).
    local right = self:split()
    if match.right then
      table.insert(self._ignored, match.right)
    end
    table.insert(right._ignored, match.left)
    return { self, right }, false
  elseif match.comment then
    local ctext = get_node_text(match.comment, self._bufnr)
    local assert_stack = Stack.new(ctext)
    local assert_row = match.comment:range()
    self:report("Verifying assertion row", assert_row)
    if self._stack:match(assert_stack, match.strict) then
      return { self }, false
    end
    self:report(string.format("Assertion failed at row %d: \"%s\"", assert_row, ctext))
    return {}, true
  end

  -- Every analysis pattern captures one of the keys above. This is only a
  -- safety net so the driver loop always receives a list.
  return { self }, false
end

function State:print()
  print("State", self._id)
  self._stack:print()
end

-- Statements simulated inside a function body. Every pattern also captures
-- the statement as @node, which is used for ordering and branch skipping.
local ANALYZE_QUERY = [[
;; Root call
(call_expression function: (identifier) @fname) @call @node

(if_statement consequence: (_) @left alternative: (_)? @right) @ifs @node

((comment) @comment @node (#lua-match? @comment "^//sa "))
((comment) @comment @node @strict (#lua-match? @comment "^//sA "))

(return_statement) @ret @node]]
-- Parsed once and reused for every analysed function.
local analyze_query = parse_query('c', ANALYZE_QUERY)

-- Upper bound on simulation steps per function. Branching can make the
-- number of paths grow exponentially.
local MAX_ITERATIONS = 10000

--- Simulate every path through `func` and, if none fails and the analysis
-- completes, register a contract for `name` derived from its annotations.
function FunctionTree:analyse_function(name, func)
  print("Analysing", name)
  func.done = true

  local steps = {}
  for _, match in analyze_query:iter_matches(func.body, self._bufnr) do
    steps[#steps + 1] = lookup(analyze_query, match)
  end

  -- Sort the steps in the order of appearance in the file
  table.sort(steps, function(a, b)
    local a_srow, a_scol = a.node:range()
    local b_srow, b_scol = b.node:range()
    return a_srow < b_srow or (a_srow == b_srow and a_scol < b_scol)
  end)

  local states = { State.new(name, func, self._bufnr, steps) }
  local iterations = 0
  local errored = false
  while #states > 0 and iterations < MAX_ITERATIONS do
    local s = table.remove(states)
    local new_states, has_error = s:step(self)
    errored = errored or has_error
    for _, new in ipairs(new_states) do
      states[#states + 1] = new
    end
    iterations = iterations + 1
  end

  if #states > 0 then
    -- Unexplored paths remain: the function was not fully verified.
    print("Analysis of", name, "aborted after", MAX_ITERATIONS, "steps")
    errored = true
  end

  if not errored then
    self._contracts[name] = {
      pre = function(_call, _caller, stack, _bufnr)
        if func.up then
          return stack:match(func.up)
        end
        if func.before then
          return stack:match(func.before)
        end
        -- No entry annotation to check against: nothing can be violated.
        return true
      end,
      post = function(_call, _caller, stack, _bufnr)
        if func.before then
          for _ = 1, func.before:len() do
            stack:pop()
          end
        end
        if func.after then
          stack:push_all(func.after)
        end
      end,
    }
  end
end

--- Analyse the named functions, or every annotated function when
-- `funcnames` is nil. Functions already analysed are skipped.
function FunctionTree:analyse(funcnames)
  if funcnames then
    for _, fname in ipairs(funcnames) do
      local to_analyze = self._functions[fname]
      if to_analyze and not to_analyze.done then
        self:analyse_function(fname, to_analyze)
      end
    end
  else
    for name, func in pairs(self._functions) do
      if not func.done then
        self:analyse_function(name, func)
      end
    end
  end
end

--- Collect the annotated C functions of `buffer` (current buffer by default).
-- @treturn FunctionTree
function FunctionTree.new(buffer)
  buffer = buffer or api.nvim_get_current_buf()

  local self = setmetatable({
    _bufnr = buffer,
    _functions = {},
    _contracts = vim.deepcopy(default_contracts),
  }, FunctionTree)

  local parser = ts.get_parser(buffer, 'c', {})
  local root = parser:parse()[1]:root()

  for _, match in query:iter_matches(root, buffer) do
    local captures = lookup(query, match)
    local fname = get_node_text(captures.fname, buffer)
    local func = self._functions[fname]
    if not func then
      func = { body = captures.body, done = false }
      self._functions[fname] = func
    end
    for _, id in ipairs({ "up", "before", "after" }) do
      local node = captures[id .. "_comment"]
      if node then
        func[id] = Stack.new(get_node_text(node, buffer))
      end
    end
  end
  return self
end

M.FunctionTree = FunctionTree

return M