-- comprehension.lua
local utils = require 'pl.utils'
local status,lb = pcall(require, "pl.luabalanced")
if not status then
lb = require 'luabalanced'
end
local math_max = math.max
local table_concat = table.concat
-- Accumulator templates, keyed by comprehension operator name.
--   init  : source text of the initial value assigned to `__result`
--   accum : source text of the statement executed for each generated element;
--           `%s` is the slot that receives the output expression list
local ops
do
  -- min and max use the same template and differ only in the comparison
  -- operator spliced into it.
  local function extremum(cmp)
    return {
      init = ' nil ',
      accum = ' local __tmp = %s ' ..
        ' if __result then if __tmp ' .. cmp .. ' __result then ' ..
        '__result = __tmp end else __result = __tmp end ',
    }
  end

  ops = {
    list = {
      init = ' {} ',
      accum = ' __result[#__result+1] = (%s) ',
    },
    table = {
      init = ' {} ',
      accum = ' local __k, __v = %s __result[__k] = __v ',
    },
    sum = {
      init = ' 0 ',
      accum = ' __result = __result + (%s) ',
    },
    min = extremum('<'),
    max = extremum('>'),
  }
end
--- Split a comprehension string into its components.
-- Shape recognised by the code below:
--   [opname "("] explist { "for" namelist [ ("=" | "in") explist ] }
--   { "if" expression } [")"]
-- At least one "for" clause is required, and loop variable names may not
-- start with "__" (that prefix is used by the generated code).
-- @param expr        comprehension source string
-- @return out        output expression list, joined with ', '
-- @return fortypes   per "for" clause: '=', 'in', or false when the clause
--                    names no source
-- @return invarlists per "for" clause: list of loop variable names
-- @return invallists per "for" clause: list of source expressions, or false
-- @return preds      list of "if" predicate expressions, in source order
-- @return opname     operator name; "list" when expr is not wrapped in
--                    `name( ... )`
-- @return max_param  largest N among `_N` identifiers found, 0 when none
local function parse_comprehension(expr)
  local pos = 1

  -- Optional operator: an identifier followed by a bracketed group that runs
  -- to the end of the string (only whitespace may follow the group).
  local opname
  local tok, post = expr:match('^%s*([%a_][%w_]*)%s*%(()', pos)
  local pose = #expr + 1  -- position at which parsing is required to end
  if tok then
    -- post-1 is the position of the opening '('.
    -- NOTE(review): posb is presumably the position just past the matching
    -- ')' -- confirm against luabalanced.
    local tok2, posb = lb.match_bracketed(expr, post-1)
    assert(tok2, 'syntax error')
    if expr:match('^%s*$', posb) then
      opname = tok
      pose = posb - 1
      pos = post  -- continue parsing right after the '('
    end
  end
  opname = opname or "list"

  -- Output expression list.
  local out; out, pos = lb.match_explist(expr, pos)
  assert(out, "syntax error: missing expression list")
  out = table_concat(out, ', ')

  -- "for" clauses.
  local fortypes = {}
  local invarlists = {}
  local invallists = {}
  while true do
    local after_for = expr:match('^%s*for%s+()', pos)
    if not after_for then break end
    pos = after_for

    local iv; iv, pos = lb.match_namelist(expr, pos)
    assert(#iv > 0, 'syntax error: zero variables')
    for _, ident in ipairs(iv) do
      assert(not ident:match'^__',
        "identifier " .. ident .. " may not contain __ prefix")
    end
    invarlists[#invarlists+1] = iv

    -- Either "= explist" (numeric for), "in explist" (generic for), or
    -- nothing at all; the last case is recorded as false/false.
    local fortype, after_op = expr:match('^(=)%s*()', pos)
    if not fortype then
      fortype, after_op = expr:match('^(in)%s+()', pos)
    end
    if fortype then
      pos = after_op
      local il; il, pos = lb.match_explist(expr, pos)
      assert(#il > 0, 'syntax error: zero expressions')
      assert(fortype ~= '=' or #il == 2 or #il == 3,
        'syntax error: numeric for requires 2 or three expressions')
      fortypes[#invarlists] = fortype
      invallists[#invarlists] = il
    else
      fortypes[#invarlists] = false
      invallists[#invarlists] = false
    end
  end
  assert(#invarlists > 0, 'syntax error: missing "for" clause')

  -- "if" clauses.
  local preds = {}
  while true do
    local after_if = expr:match('^%s*if%s+()', pos)
    if not after_if then break end
    pos = after_if
    local pred; pred, pos = lb.match_expression(expr, pos)
    assert(pred, 'syntax error: predicated expression not found')
    preds[#preds+1] = pred
  end

  -- Gather the parts of expr that lb.gsub reports with type 'e'
  -- (presumably plain code, i.e. neither strings nor comments -- confirm
  -- against luabalanced). Pieces go into a buffer and are joined once,
  -- instead of growing a string on every callback.
  local pieces = {}
  lb.gsub(expr, function(u, sin)
    if u == 'e' then pieces[#pieces+1] = ' ' .. sin .. ' ' end
  end)

  -- Largest placeholder index among identifiers of the form _<digits>.
  local max_param = 0
  for name in table_concat(pieces):gmatch('[%a_][%w_]*') do
    local digits = name:match('^_(%d+)$')
    if digits then max_param = math_max(max_param, tonumber(digits)) end
  end

  -- Anything left between pos and the expected end is unparsed input.
  -- Level 0 keeps the message free of a position prefix.
  if pos ~= pose then
    error("syntax error: unrecognized " .. expr:sub(pos), 0)
  end

  return out, fortypes, invarlists, invallists, preds, opname, max_param
end
--- Generate the Lua source text that evaluates a parsed comprehension.
-- The result declares `local __result`, initialised from the operator's
-- `init` template, followed by the nested loops/conditions that update it.
-- It does not include the final `return`.
-- @param out         output expression list (string)
-- @param fortypes    per "for" clause: '=', 'in', or false
-- @param invarlists  per "for" clause: list of loop variable names
-- @param invallists  per "for" clause: list of source expressions, or false
-- @param preds       list of predicate expressions
-- @param opname      key into `ops`
-- @param max_param   accepted for signature parity; not used here
-- @return generated source text
local function code_comprehension(
    out, fortypes, invarlists, invallists, preds, opname, max_param
)
  local op = assert(ops[opname],
    'unknown comprehension operator: ' .. tostring(opname))

  -- Substitute through a function, not a replacement string: gsub interprets
  -- '%' inside a replacement string, which corrupted (Lua 5.1) or rejected
  -- (Lua 5.2+) output expressions that use the modulo operator.
  local code = op.accum:gsub('%%s', function() return out end)

  -- Wrap in predicates; the last predicate ends up innermost.
  for i = #preds, 1, -1 do
    code = ' if ' .. preds[i] .. ' then ' .. code .. ' end '
  end

  -- Wrap in loops; the last "for" clause ends up innermost.
  for i = #invarlists, 1, -1 do
    if not fortypes[i] then
      -- No explicit source: iterate over the array bound to __in<i>.
      -- Only the first declared variable is bound.
      local arrayname = '__in' .. i
      local idx = '__idx' .. i
      code =
        ' for ' .. idx .. ' = 1, #' .. arrayname .. ' do ' ..
        ' local ' .. invarlists[i][1] .. ' = ' .. arrayname .. '[' .. idx .. '] ' ..
        code .. ' end '
    else
      code =
        ' for ' .. table_concat(invarlists[i], ', ') ..
        ' ' .. fortypes[i] .. ' ' ..
        table_concat(invallists[i], ', ') ..
        ' do ' .. code .. ' end '
    end
  end

  return ' local __result = ( ' .. op.init .. ' ) ' .. code
end
--- Compile generated comprehension code into a callable function.
-- The chunk's varargs are bound, in order, to the placeholders _1.._N
-- (N = max_param) and then to one __in<i> for every "for" clause whose
-- entry in invallists is false; the chunk ends by returning __result.
-- @param code        source text that defines `__result`
-- @param ninputs     number of "for" clauses (must be > 0)
-- @param max_param   highest placeholder index
-- @param invallists  per clause: list of source expressions, or false
-- @param env         environment handed to utils.load
-- @return the loaded function; raises if the source does not load
local function wrap_comprehension(code, ninputs, max_param, invallists, env)
  assert(ninputs > 0)

  local argnames = {}
  for n = 1, max_param do
    argnames[#argnames+1] = '_' .. n
  end
  for n = 1, ninputs do
    if not invallists[n] then
      argnames[#argnames+1] = '__in' .. n
    end
  end

  local header = ''
  if #argnames > 0 then
    header = ' local ' .. table_concat(argnames, ', ') .. ' = ... '
  end
  local source = header .. code .. ' return __result '

  local chunk, err = utils.load(source, 'tmp', 't', env)
  if not chunk then
    -- Level 0: raise the message as-is, without a position prefix.
    error(err .. ' with generated code ' .. source, 0)
  end
  return chunk
end
--- Compile a comprehension string into a function: parse, generate, wrap.
-- @param expr  comprehension source string
-- @param env   environment for the compiled function
-- @return the compiled function
local function build_comprehension(expr, env)
  local out, fortypes, invarlists, invallists, preds, opname, max_param =
    parse_comprehension(expr)
  local body = code_comprehension(
    out, fortypes, invarlists, invallists, preds, opname, max_param)
  -- Parentheses keep the result to a single value.
  return (wrap_comprehension(body, #invarlists, max_param, invallists, env))
end
--- Create a comprehension cache bound to an environment.
-- Both `cache[expr]` and `cache(expr)` return the function compiled from
-- the comprehension string `expr`; compiled functions are stored in the
-- cache table under `expr`. `cache.new` is this constructor.
-- @param env  optional environment for compiled functions; when omitted,
--             `utils.getfenv(2)` is used (presumably the caller's
--             environment -- confirm against pl.utils)
-- @return the cache table
local function new(env)
  if not env then
    env = utils.getfenv(2)
  end

  local mt = {}
  local cache = setmetatable({}, mt)

  -- Runs only when `expr` is not yet a field of the cache: compile, store,
  -- return.
  function mt:__index(expr)
    local f = build_comprehension(expr, env)
    self[expr] = f
    return f
  end

  -- Call syntax: cache 'x for x'. Look at stored entries first so repeated
  -- calls reuse the compiled function instead of re-parsing and re-loading
  -- the expression every time. The `new` field is the constructor, not a
  -- compiled comprehension, so it is never served as a cache hit.
  function mt:__call(expr)
    if expr ~= 'new' then
      local hit = rawget(self, expr)
      if hit ~= nil then return hit end
    end
    return mt.__index(self, expr)
  end

  cache.new = new
  return cache
end
-- Module table: the constructor is the only exported entry.
local comprehension = { new = new }
return comprehension