data.lua
local utils = require 'pl.utils'
local _DEBUG = rawget(_G,'_DEBUG')
local patterns,function_arg,usplit,array_tostring = utils.patterns,utils.function_arg,utils.split,utils.array_tostring
local append,concat = table.insert,table.concat
local gsub = string.gsub
local io = io
local _G,print,type,tonumber,ipairs,setmetatable,pcall,error = _G,print,type,tonumber,ipairs,setmetatable,pcall,error
local data = {}
local parse_select
--- Count how many times a literal piece of text occurs in a string.
-- @param s the string to search
-- @param chr the text to look for; it is escaped with utils.escape so that
-- pattern magic characters are matched literally
-- @return the number of substitutions gsub performed, i.e. the match count
local function count(s,chr)
    local pattern = utils.escape(chr)
    local _,hits = s:gsub(pattern,' ')
    return hits
end
--- Remove trailing whitespace from a string.
-- @param s the string to trim
-- @return the string without trailing whitespace (a single value)
local function rstrip(s)
    local trimmed = s:gsub('%s+$','')
    return trimmed
end
--- Remove whitespace from both ends of a string.
-- @param s the string to trim
-- @return the trimmed string (a single value)
local function strip (s)
    local tail_trimmed = rstrip(s)
    local both_trimmed = tail_trimmed:gsub('^%s*','')
    return both_trimmed
end
--- Give a plain table the List metatable taken from utils.stdmt.
-- @param l the table to convert (modified in place)
-- @return the same table, now carrying the List metatable
local function make_list(l)
    local list_mt = utils.stdmt.List
    return setmetatable(l,list_mt)
end
--- Apply a function to each element of an array, collecting the results.
-- The function is called with exactly one argument, the element.
-- @param fun function of one argument
-- @param t array-like table
-- @return a new table whose slot i holds fun(t[i])
local function map(fun,t)
    local out,len = {},#t
    local i = 1
    while i <= len do
        out[i] = fun(t[i])
        i = i + 1
    end
    return out
end
--- Split a line into fields, with optional handling of quoted CSV fields.
-- In CSV mode, double-quoted sections lose their quotes and any commas
-- inside them are hidden behind '\001' while splitting, then restored.
-- @param line the text to split
-- @param delim delimiter handed to utils.split
-- @param csv truthy to enable the quoted-field handling
-- @param n optional maximum number of fields, handed to utils.split
-- @return a List of field strings
local function split(line,delim,csv,n)
    local restore
    if csv and line:match '"' then
        local had_comma = false
        line = line:gsub('"([^"]+)"',function(quoted)
            local hidden,hits = quoted:gsub(',','\001')
            if hits > 0 then had_comma = true end
            return hidden
        end)
        if had_comma then
            -- only pay for the restoring pass when something was hidden
            restore = function(field) return (field:gsub('\001',',')) end
        end
    end
    local parts = (usplit(line,delim,false,n))
    if csv then
        if restore then parts = map(restore,parts) end
        -- a trailing comma stands for a final empty field
        if line:match ',$' then append(parts,'') end
    end
    return make_list(parts)
end
--- Linear search for a value in an array.
-- @param t array-like table
-- @param v value to look for (compared with ==)
-- @return index of the first match; returns nothing when there is none
local function find(t,v)
    local i,n = 1,#t
    while i <= n do
        if v == t[i] then return i end
        i = i + 1
    end
end
--- Methods shared by all data objects (used as their metatable).
local DataMT = {
    --- Return one column as a List.
    -- @param self the data object
    -- @param name a field name, or a number (converted to the '$N' form)
    -- @return List of the column values
    -- Raises an error carrying the query's message if the column query
    -- cannot be built (e.g. unknown field name).
    column_by_name = function(self,name)
        if type(name) == 'number' then
            name = '$'..name
        end
        local iter,err = data.query(self,name)
        -- report the real reason instead of failing later on a nil iterator
        if not iter then error(err,2) end
        local arr = {}
        for res in iter do
            append(arr,res)
        end
        return make_list(arr)
    end,
    --- Build a new data object from the rows produced by a select string.
    -- @param self the data object
    -- @param condn a select string
    -- @return a new data object with the selected fields as its field names
    -- Raises an error carrying the parser's/query's message on failure.
    copy_select = function(self,condn)
        local parms,err = parse_select(condn,self)
        if not parms then error(err,2) end
        local iter
        iter,err = data.query(self,parms)
        if not iter then error(err,2) end
        local res = {}
        local row = make_list{iter()}
        -- the iterator yields no values once the rows are exhausted
        while #row > 0 do
            append(res,row)
            row = make_list{iter()}
        end
        res.delim = self.delim
        return data.new(res,split(parms.fields,','))
    end,
    --- Return the field names of this data object.
    -- @param self the data object
    column_names = function(self)
        return self.fieldnames
    end,
}
local array2d
--- Method lookup for data objects.
-- Names defined on DataMT win; anything else is looked up in pl.array2d,
-- which is only required the first time such a lookup happens.
DataMT.__index = function(self,name)
    local method = DataMT[name]
    if not method then
        array2d = array2d or require 'pl.array2d'
        method = array2d[name]
    end
    return method
end
-- candidate delimiters, tried in this order
local delims = {',','\t',' ',';'}
--- Guess the field delimiter used by a line of text.
-- @param line the line to inspect
-- @return the first candidate that occurs in the line; a space candidate is
-- reported as the pattern '%s+'. Returns ' ' for an empty line or when no
-- candidate occurs.
local function guess_delim (line)
    if line ~= '' then
        for i = 1,#delims do
            local candidate = delims[i]
            if count(line,candidate) > 0 then
                if candidate == ' ' then return '%s+' end
                return candidate
            end
        end
    end
    return ' '
end
--- Resolve a file argument into something that can be read or written.
-- @param f a file name, the special names 'stdin'/'stdout', or an object
-- that already has a read (for mode 'r') or write (any other mode) method
-- @param mode mode string passed to io.open; 'r' means reading, anything
-- else is treated as writing
-- @return the file-like object, or nil plus an error message
-- @return nil on success (placeholder so the third value lines up)
-- @return true if this function opened the file, so the caller must close it
local function open_file (f,mode)
    local opened, err
    local reading = mode == 'r'
    if type(f) == 'string' then
        if f == 'stdin' then
            f = io.stdin
        elseif f == 'stdout' then
            f = io.stdout
        else
            f,err = io.open(f,mode)
            if not f then return nil,err end
            opened = true
        end
    end
    -- only tables and userdata can carry read/write methods; checking the
    -- type first avoids raising on numbers/booleans and gives nil a message
    local kind = type(f)
    if kind ~= 'table' and kind ~= 'userdata' then
        return nil, "not a file-like object"
    end
    if (reading and not f.read) or (not reading and not f.write) then
        return nil, "not a file-like object"
    end
    return f,nil,opened
end
-- Empty placeholder function: takes no arguments and returns nothing.
local all_n = function ()
end
--- Read a delimited text file into a data object.
-- The first line is used to guess the delimiter (unless given) and, when it
-- is not made up entirely of numbers, supplies the field names.
-- @param file a file name, 'stdin', or a file-like object with a read method
-- @param cnfg optional table of options (note: its delim and fieldnames
-- entries may be filled in or replaced by this function):
--
--  * delim: delimiter pattern; guessed from the first line when absent
--  * csv: truthy forces delim ',' and enables quoted-field handling
--  * fieldnames: list of names, or a delimited string of names; when given,
--    the first line of the file is treated as data
--  * numfields: list of column indexes to convert with the number parser
--  * convert: table of column index -> converter function; a converter gets
--    the field and returns the value, or nil plus a message. Only consulted
--    when fieldnames is not supplied.
--  * no_convert: truthy disables detecting numeric columns from the first
--    data row
--  * last_field_collect: truthy limits splitting to the number of field names
--  * thousands_dot: truthy removes a '.' followed by three characters before
--    converting a field to a number
-- @return a data object, or nil plus an error message
function data.read(file,cnfg)
    local count,line,csv
    local D = {}
    if not cnfg then cnfg = {} end
    local f,err,opened = open_file(file,'r')
    if not f then return nil, err end

    -- every failure path goes through here so a file we opened is not leaked
    local function fail (msg)
        if opened then f:close() end
        return nil, msg
    end

    local thousands_dot = cnfg.thousands_dot
    local function try_number(x)
        -- a row shorter than expected yields a nil field
        if x == nil then return nil,"not a number" end
        if thousands_dot then x = x:gsub('%.(...)','%1') end
        local v = tonumber(x)
        if v == nil then return nil,"not a number" end
        return v
    end

    csv = cnfg.csv
    if csv then cnfg.delim = ',' end
    count = 1
    line = f:read()
    if not line then return fail "empty file" end

    D.delim = cnfg.delim or guess_delim(line)
    local delim = D.delim
    local conversion
    local numfields = {}
    -- numfields[k] is a column index, conversion[k] the converter for it
    local function append_conversion (idx,conv)
        conversion = conversion or {}
        append(numfields,idx)
        append(conversion,conv)
    end
    if cnfg.numfields then
        for _,n in ipairs(cnfg.numfields) do append_conversion(n,try_number) end
    end

    -- with whitespace-delimited data, leading whitespace would otherwise
    -- produce an empty first field
    local stripper
    if delim == '%s+' and line:find(delim) == 1 then
        stripper = function(s) return (s:gsub('^%s+','')) end
        line = stripper(line)
    end

    if not cnfg.fieldnames then
        local fields,nums
        fields = split(line,delim,csv)
        if not cnfg.convert then
            -- the first line is data only if *every* field is a number;
            -- check each one rather than comparing lengths of a table that
            -- may contain holes
            local all_numeric = true
            nums = {}
            for i = 1,#fields do
                local v = tonumber(fields[i])
                if v == nil then
                    all_numeric = false
                    break
                end
                nums[i] = v
            end
            if all_numeric then
                append(D,nums)
                for i = 1,#nums do append_conversion(i,try_number) end
            else
                nums = nil
            end
        else
            for idx,conv in pairs(cnfg.convert) do append_conversion(idx,conv) end
        end
        if nums == nil then
            cnfg.fieldnames = fields
        end
        line = f:read()
        count = count + 1
        -- line is nil when the file holds a single line
        if line and stripper then line = stripper(line) end
    elseif type(cnfg.fieldnames) == 'string' then
        cnfg.fieldnames = split(cnfg.fieldnames,delim,csv)
    end

    local nfields
    if cnfg.fieldnames then
        D.fieldnames = cnfg.fieldnames
        if cnfg.last_field_collect then
            nfields = #D.fieldnames
        end
        -- use the first data row (if there is one) to find numeric columns
        if line and not cnfg.no_convert then
            local fields = split(line,D.delim,csv,nfields)
            for i = 1,#fields do
                if not find(numfields,i) and tonumber(fields[i]) then
                    append_conversion(i,try_number)
                end
            end
        end
    end

    while line do
        if not line:find ('^%s*$') then
            if stripper then line = stripper(line) end
            local fields = split(line,delim,csv,nfields)
            if conversion then
                for k = 1,#numfields do
                    local i,conv = numfields[k],conversion[k]
                    local val,cerr = conv(fields[i])
                    if val == nil then
                        -- tostring guards against a missing field or a
                        -- converter that gave no message
                        return fail(tostring(cerr)..": "..tostring(fields[i])..
                            " at line "..count)
                    else
                        fields[i] = val
                    end
                end
            end
            append(D,fields)
        end
        line = f:read()
        count = count + 1
    end
    if opened then f:close() end
    -- '%s+' is only meaningful for splitting; store a plain space for output
    if delim == '%s+' then D.delim = ' ' end
    if not D.fieldnames then D.fieldnames = {} end
    return data.new(D)
end
--- Write one row to a file as a delimited line.
-- The row is converted with utils.array_tostring, passing data.temp as its
-- second argument; the result is stored back in data.temp.
-- @param data the data object (its temp field is overwritten)
-- @param f file-like object with a write method
-- @param row array of values
-- @param delim separator placed between values
local function write_row (data,f,row,delim)
    local cells = array_tostring(row,data.temp)
    data.temp = cells
    f:write(concat(cells,delim),'\n')
end
--- Write a single row to a file using this object's delimiter.
-- @param f file-like object with a write method
-- @param row array of values
function DataMT:write_row(f,row)
    local delim = self.delim
    write_row(self,f,row,delim)
end
--- Write a data object to a file as delimited text.
-- @param data the data object
-- @param file a file name, 'stdout', or a file-like object with a write method
-- @param fieldnames when this is a non-empty list, a header line is written
-- first; the header text itself is taken from data.fieldnames
-- @param delim separator between values (default tab); used for the header
-- line as well as for the rows
-- @return true on success, or nil plus an error message if the file cannot
-- be opened
function data.write (data,file,fieldnames,delim)
    local f,err,opened = open_file(file,'w')
    if not f then return nil, err end
    -- the default must be in place before the header is written, otherwise
    -- the header names are joined with no separator at all
    delim = delim or '\t'
    if fieldnames and #fieldnames > 0 then
        f:write(concat(data.fieldnames,delim),'\n')
    end
    for i = 1,#data do
        write_row(data,f,data[i],delim)
    end
    if opened then f:close() end
    return true
end
--- Write this data object to a file, with its own field names and delimiter.
-- @param file a file name or file-like object, passed on to data.write
function DataMT:write(file)
    local names,delim = self.fieldnames,self.delim
    data.write(self,file,names,delim)
end
--- Clean up field names in place, keeping a copy of the trimmed originals.
-- Each name is stripped of surrounding whitespace; the stripped form goes
-- into copy, and a version with every non-alphanumeric character replaced by
-- '_' is written back into fields.
-- @param fields array of field names (modified in place)
-- @param copy table receiving the stripped names
local function massage_fieldnames (fields,copy)
    local total = #fields
    for idx = 1,total do
        local trimmed = strip(fields[idx])
        local safe = trimmed:gsub('%W','_')
        copy[idx] = trimmed
        fields[idx] = safe
    end
end
--- Turn a table of rows into a data object.
-- @param d table of rows; may already carry fieldnames and delim fields.
-- It is modified in place and given the data metatable.
-- @param fieldnames list of field names, or a delimited string of names;
-- only used when d.fieldnames is not set
-- @return d, now with fieldnames (a List of cleaned names),
-- original_fieldnames and, when it was derived here, delim
function data.new (d,fieldnames)
    local names = d.fieldnames or fieldnames or ''
    if type(names) == 'string' then
        -- a string of names must always be split, whether or not the
        -- delimiter is already known; otherwise it would reach make_list
        -- as a string
        local pattern
        if d.delim then
            pattern = d.delim == ' ' and '%s+' or d.delim
        else
            pattern = guess_delim(names)
            -- '%s+' is a splitting pattern, not something to write out;
            -- store a plain space as data.read does
            d.delim = pattern == '%s+' and ' ' or pattern
        end
        names = split(names,pattern)
    end
    d.fieldnames = make_list(names)
    d.original_fieldnames = {}
    massage_fieldnames(d.fieldnames,d.original_fieldnames)
    setmetatable(d,DataMT)
    return d
end
-- Source template for a query that sorts its results.
-- data.query substitutes the placeholders CONDITION, SORT_EXPR and FIELDLIST
-- and compiles the text. The generated function takes the data table t,
-- collects every row v for which CONDITION holds into ls, sorts ls with a
-- comparator whose body is SORT_EXPR (over rows v1 and v2), and returns an
-- iterator that evaluates FIELDLIST for each collected row in turn.
-- The text is program input for the Lua compiler; keep it byte-for-byte.
local sorted_query = [[
return function (t)
local i = 0
local v
local ls = {}
for i,v in ipairs(t) do
if CONDITION then
ls[#ls+1] = v
end
end
table.sort(ls,function(v1,v2)
return SORT_EXPR
end)
local n = #ls
return function()
i = i + 1
v = ls[i]
if i > n then return end
return FIELDLIST
end
end
]]
-- Source template for a query without sorting.
-- data.query substitutes the placeholders CONDITION and FIELDLIST and
-- compiles the text. The generated function takes the data table t and
-- returns an iterator that advances to the next row v for which CONDITION
-- holds and evaluates FIELDLIST for it; it returns nothing once the index
-- has passed the last row.
-- The text is program input for the Lua compiler; keep it byte-for-byte.
local simple_query = [[
return function (t)
local n = #t
local i = 0
local v
return function()
repeat
i = i + 1
v = t[i]
until i > n or CONDITION
if i > n then return end
return FIELDLIST
end
end
]]
--- Test whether a value is a string.
-- @param s any value
-- @return true if s is a string, false otherwise
local function is_string (s)
    local kind = type(s)
    return kind == 'string'
end
-- Message describing the most recent failed field lookup; written by
-- massage_fields and inspected/reset by the query-building code.
local field_error
--- Join the field names of a data object into a comma-separated string.
-- @param data the data object
-- @return the names joined with ','
local function fieldnames_as_string (data)
    local names = data.fieldnames
    return concat(names,',')
end
--- Translate a field reference into an index expression on the row 'v'.
-- @param data the data object whose field names are searched
-- @param f a field name, or a string of digits giving a column number
-- @return 'v[N]' when the field resolves to column N; otherwise f itself is
-- returned unchanged and field_error is set to a message naming the field
local function massage_fields(data,f)
    -- a digit string always converts, so the and/or form is safe here
    local idx = f:find '^%d+$' and tonumber(f) or find(data.fieldnames,f)
    if not idx then
        field_error = f..' not found in '..fieldnames_as_string(data)
        return f
    end
    return 'v['..idx..']'
end
--- Prepare a parsed select for code generation.
-- Expands '*', rewrites field references into 'v[N]' expressions, and fills
-- in parms.fields, parms.proc_fields and parms.where.
-- @param data the data object being queried
-- @param parms table with a fields string and optional where entry
-- (modified in place)
-- @return true on success, or nil plus a message when a field is unknown
local function process_select (data,parms)
    field_error = nil
    local fields = parms.fields
    -- '$N' references are used when asked for explicitly or when the data
    -- has no field names at all
    local numfields = fields:find '%$' or #data.fieldnames == 0
    if fields:find '^%s*%*%s*' then
        if not numfields then
            fields = fieldnames_as_string(data)
        else
            -- a data set without rows has no first row to measure
            local first = data[1]
            local ncol = first and #first or 0
            local cols = {}
            for i = 1,ncol do cols[i] = '$'..i end
            fields = concat(cols,',')
        end
    end
    local idpat = patterns.IDEN
    if numfields then
        idpat = '%$(%d+)'
    else
        -- field names were cleaned the same way when the data was created
        fields = rstrip(fields):gsub('[^,%w]','_')
    end
    local massage_fields = utils.bind1(massage_fields,data)
    local ret = gsub(fields,idpat,massage_fields)
    if field_error then
        -- hand the message back and clear the shared flag so it cannot be
        -- mistaken for a failure of a later lookup
        local msg = field_error
        field_error = nil
        return nil,msg
    end
    parms.fields = fields
    parms.proc_fields = ret
    parms.where = parms.where or 'true'
    if is_string(parms.where) then
        parms.where = gsub(parms.where,idpat,massage_fields)
        -- identifiers in a condition need not be field names, so failed
        -- lookups here are not errors
        field_error = nil
    end
    return true
end
--- Break a select string into its fields, 'where' and 'sort by' parts.
-- @param s select string of the form
-- FIELDS [where CONDITION] [sort by FIELD [asc|desc]]
-- @param data the data object the select applies to
-- @return the processed parameter table, or nil plus an error message
parse_select = function(s,data)
    local parms = {}
    local where_from,where_to = s:find('where ')
    local sort_from,sort_to = s:find('sort by ')
    if where_from then
        -- the condition runs up to the sort clause, or to the end (-1)
        local stop = (sort_from or 0) - 1
        parms.where = s:sub(where_to+1,stop)
    end
    if sort_from then
        parms.sort_by = s:sub(sort_to+1)
    end
    -- the field list is everything before the first clause keyword
    local fields_end = (where_from or sort_from or 0) - 1
    parms.fields = s:sub(1,fields_end)
    local ok,err = process_select(data,parms)
    if ok then return parms end
    return nil,err
end
--- Create an iterator over the rows of a data object that match a query.
-- @param data the data object
-- @param condn a select string, or a table with fields (string or list),
-- optional where (a condition string, or something callable taking a row)
-- and optional sort_by. A table is processed in place the first time it is
-- used (fields, proc_fields and where are filled in).
-- @param context optional list of tables searched, in order, for names used
-- by the generated code; _G is appended to it as the last resort
-- @param return_row truthy makes each iteration return one table holding the
-- selected values instead of returning them as separate values
-- @return an iterator function, or nil plus an error message
function data.query(data,condn,context,return_row)
    local err
    if is_string(condn) then
        condn,err = parse_select(condn,data)
        if not condn then return nil,err end
    elseif type(condn) == 'table' then
        if type(condn.fields) == 'table' then
            condn.fields = concat(condn.fields,',')
        end
        if not condn.proc_fields then
            local status
            status,err = process_select(data,condn)
            if not status then return nil,err end
        end
    else
        return nil, "condition must be a string or a table"
    end

    -- User text is spliced in through a function replacement: a replacement
    -- *string* would have any '%' in it (e.g. the modulo operator)
    -- interpreted by gsub.
    local function literal (text)
        return function() return text end
    end

    local query = condn.sort_by and sorted_query or simple_query
    local fields = condn.proc_fields or condn.fields
    if return_row then
        fields = '{'..fields..'}'
    end
    query = query:gsub('FIELDLIST',literal(fields))

    -- condn itself is left untouched here so the same table can be used for
    -- further queries
    local where,where_fn = condn.where,nil
    if is_string(where) then
        query = query:gsub('CONDITION',literal(where))
    else
        query = query:gsub('CONDITION','_condn(v)')
        where_fn = function_arg(0,where,'condition.where must be callable')
    end

    if condn.sort_by then
        local sort_var,sort_dir
        local sort_by = condn.sort_by
        local i1,i2 = sort_by:find('%s+')
        if i1 then
            sort_var,sort_dir = sort_by:sub(1,i1-1),sort_by:sub(i2+1)
        else
            sort_var,sort_dir = sort_by,'asc'
        end
        if sort_var:match '^%$' then sort_var = sort_var:sub(2) end
        -- start from a clean flag so an earlier failed lookup is not
        -- reported against this sort field
        field_error = nil
        sort_var = massage_fields(data,sort_var)
        if field_error then return nil,field_error end
        local op = sort_dir == 'asc' and '<' or '>'
        -- sort_var is 'v[N]' here; compare the same column of both rows
        local lhs = sort_var:gsub('v','v1')
        local rhs = sort_var:gsub('v','v2')
        local expr = ('%s %s %s'):format(lhs,op,rhs)
        query = query:gsub('SORT_EXPR',literal(expr))
    end

    if where_fn then
        -- a callable condition is passed in through an extra closure level
        query = 'return function(_condn) '..query..' end'
    end
    if _DEBUG then print(query) end
    local fn
    fn,err = utils.load(query,'tmp')
    if not fn then return nil,err end
    fn = fn()
    if where_fn then
        fn = fn(where_fn)
    end
    local qfun = fn(data)
    if context then
        -- avoid growing the caller's list on every query
        if context[#context] ~= _G then append(context,_G) end
        local lookup = {}
        utils.setfenv(qfun,lookup)
        setmetatable(lookup,{
            __index = function(tbl,key)
                for _,t in ipairs(context) do
                    local value = t[key]
                    -- a stored false is a real value, not a miss
                    if value ~= nil then return value end
                end
            end
        })
    end
    return qfun
end
--- Method form of data.query: iterate over the selected values.
DataMT.select = data.query
--- Like select, but each iteration returns one table of the selected values.
-- @param d the data object
-- @param condn select string or condition table, as for data.query
-- @param context optional lookup context, as for data.query
-- @return an iterator function, or nil plus an error message
function DataMT.select_row (d,condn,context)
    local as_row = true
    return data.query(d,condn,context,as_row)
end
--- Read a file, run a query over it and write the matching values out.
-- @param Q the select string
-- @param infile input file name or file-like object (default 'stdin')
-- @param outfile output file name or file-like object (default 'stdout')
-- @param dont_fail truthy makes failures return nil plus a message; otherwise
-- utils.quit(1, message) is called
-- @return nothing on success; nil plus an 'error: ...' message on failure
-- when dont_fail is set
function data.filter (Q,infile,outfile,dont_fail)
    local function fail (msg)
        msg = 'error: '..tostring(msg)
        if dont_fail then
            return nil,msg
        else
            utils.quit(1,msg)
        end
    end

    local d,err = data.read(infile or 'stdin')
    if not d then return fail(err) end
    -- build the query before touching the output, so a bad query does not
    -- truncate an existing output file
    local iter
    iter,err = d:select(Q)
    if not iter then return fail(err) end
    -- the mode must be given: without it a named file is opened read-only
    local out,opened
    out,err,opened = open_file(outfile or 'stdout','w')
    if not out then return fail(err) end

    local delim = d.delim
    while true do
        local res = {iter()}
        if #res == 0 then break end
        out:write(concat(res,delim),'\n')
    end
    if opened then out:close() end
end
return data