From e542e294ff62fd7fd05817e7b08a44f625b58a3d Mon Sep 17 00:00:00 2001 From: Tucker Johnson Date: Mon, 23 Jun 2025 08:25:18 -0400 Subject: [PATCH] generator --- core.lua | 75 +++++------- init.lua | 2 +- snippets.lua | 320 +++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 351 insertions(+), 46 deletions(-) create mode 100644 snippets.lua diff --git a/core.lua b/core.lua index 5a7ba3b..8d46c90 100644 --- a/core.lua +++ b/core.lua @@ -1,7 +1,4 @@ -local ok, bit = pcall(require, "bit32") -if not ok then - bit = require("bit") -- LuaJIT fallback -end +bit = require("bit") local M = {} @@ -14,9 +11,12 @@ local function load_csv() local f = io.open(csvPath, "r") if not f then error("Could not open setclasses.csv at " .. csvPath) end for line in f:lines() do - local dec, setClass = line:match("^(%d+),%[?([0-9AB ]+)%]?") - if dec and setClass then - lookup[tonumber(dec)] = "[" .. setClass .. "]" + local dec, sc, icv = line:match("^(%d+),%[?([0-9AB ]+)%]?%,%<([0-9 ]+)%>$") + if dec and sc then + lookup[tonumber(dec)] = { + sc = "[" .. sc .. "]", + icv = "<" .. icv .. ">" + } end end f:close() @@ -42,58 +42,43 @@ function M.analyze_set(inputStr) end local function parse_set(input) - local pcs = {} - for i = 1, #input do - local char = input:sub(i, i):upper() - local pc = tonumber(char) - if char == 'A' then pc = 10 - elseif char == 'B' then pc = 11 end - if pc and pc >= 0 and pc < 12 then - table.insert(pcs, pc) - else - error("Invalid pc: " .. char) + input = input:upper():gsub("[^0-9ABCDEF]","") + local mask = 0 + for i =1, #input do + local char = input:sub(i, i) + local pc = tonumber(char, 16) + if pc and pc <= 11 then + mask = bit.bor(mask, bit.lshift(1, pc)) end end - return pcs + return mask end - local function compute_bitmask(pcs) - local mask = 0 - for _, pc in ipairs(pcs) do - mask = bit.bor(mask, bit.lshift(1, pc)) + local function m5_bitmask(mask) + local m5Map = { [0]=0,5,10,3,8,1,6,11,4,9,2,7 } + local result = 0 + for i = 0, 11 do + if bit.band(mask, bit.lshift(1, i)) ~= 0 then + local mapped = m5Map[i] + result = bit.bor(result, bit.lshift(1, mapped)) + end end - return mask + return result end local function complement(mask) return bit.band(bit.bnot(mask), 0xFFF) end - local function multiply_set(pcs, factor) - local result = {} - for _, pc in ipairs(pcs) do - table.insert(result, (pc * factor) % 12) - end - return result - end - - local pcs = parse_set(inputStr) - local mask = compute_bitmask(pcs) - local decimal = mask - local setClass = lookup[decimal] or "[Unknown]" - local complementMask = complement(mask) - local complementClass = lookup[complementMask] or "[Unknown]" - local m5Set = multiply_set(pcs, 5) - local m5Mask = compute_bitmask(m5Set) - local m5Decimal = m5Mask - local m5Class = lookup[m5Decimal] or "[Unknown]" + local mask = parse_set(inputStr) return { input = inputStr, - decimal = decimal, - set_class = setClass, - complement_class = complementClass, - m5_class = m5Class + decimal = mask, + set_class = lookup[mask].sc, + interval_class_vector = lookup[mask].icv, + complement_class = lookup[complement(mask)].sc, + m5_class = lookup[m5_bitmask(mask)].sc } end diff --git a/init.lua b/init.lua index 2f4b9f8..df795fe 100644 --- a/init.lua +++ b/init.lua @@ -6,7 +6,7 @@ function M.analyze_selection() local result = core.analyze_set(input) if result then vim.notify("input: {" .. result.input .. "} (" .. result.decimal .. ")") - vim.notify("set class: " .. result.set_class) + vim.notify("set class: " .. result.set_class .. " " .. result.interval_class_vector) vim.notify("abs. complement: " .. result.complement_class) vim.notify("M5: " .. result.m5_class) else diff --git a/snippets.lua b/snippets.lua new file mode 100644 index 0000000..3830767 --- /dev/null +++ b/snippets.lua @@ -0,0 +1,320 @@ +-- load csv +local bit = require("bit") +local generator = require("setclass.generator") + +local M = {} + +local pluginDir = debug.getinfo(1, "S").source:sub(2):match("(.*/)") +local csvPath = pluginDir .. "set-classes.csv" + +local csvLoaded = false +local lookup = {} + +local function load_csv() + local f = io.open(csvPath, "r") + if not f then + return false + end + for line in f:lines() do + local dec, norm, prime, icv = line:match("^(%d+),([0-9AB]+),([0-9AB]+),(<[0-9AB]+>)") + if dec then + lookup[tonumber(dec)] = { + normal = norm, + prime = prime, + icv = icv + } + end + end + f:close() + csvLoaded = true + return true +end + +-- lazy load in analyze +function M.analyze_set(inputStr) + -- Try to load the CSV, generate if needed + if not csvLoaded then + local success = load_csv() + if not success then + print("[SetClass] CSV not found — generating now...") + generator.generate_csv(csvPath) + assert(load_csv(), "Failed to load generated set-classes.csv") + end + end + + -- Now safely proceed with your parse/lookup logic... +end + +-- csv generator (generator.lua) + + +-- Progress bar helper +local function show_progress(current, total, width) + width = width or 30 + local pct = current / total + local filled = math.floor(pct * width) + local bar = string.rep("#", filled) .. string.rep("-", width - filled) + io.write(string.format("\r[%s] %3d%%", bar, pct * 100)) + io.flush() +end + +-- Main CSV generation +function M.generate_csv(path) + local out = io.open(path, "w") + local total = 4095 + for mask = 1, total do + local pcs = bitmask_to_pcs(mask) + if #pcs > 0 then + -- compute norm, prime, icv... + out:write(string.format("%d,%s,%s,%s\n", mask, norm_hex, prime_hex, icv)) + end + show_progress(mask, total) + end + out:close() + io.write("\nCSV generated: " .. path .. "\n") +end + + +-- setclass/generator.lua +local M = {} + +function M.generate_csv(path) + local bit = require("bit") + + local function bitmask_to_pcs(mask) + local pcs = {} + for i = 0, 11 do + if bit.band(mask, bit.lshift(1, i)) ~= 0 then + table.insert(pcs, i) + end + end + return pcs + end + + local function pcs_to_hex(pcs) + local out = {} + for _, pc in ipairs(pcs) do + if pc < 10 then + table.insert(out, tostring(pc)) + elseif pc == 10 then + table.insert(out, "A") + else + table.insert(out, "B") + end + end + return table.concat(out) + end + + local function rotate(t, n) + local len, out = #t, {} + for i = 1, len do + out[i] = t[((i + n - 2) % len) + 1] + end + return out + end + + local function transpose_to_zero(set) + local out, root = {}, set[1] + for i = 1, #set do out[i] = (set[i] - root) % 12 end + table.sort(out) + return out + end + + local function normal_form(pcs) + local best + for i = 0, #pcs - 1 do + local rot = rotate(pcs, i) + local dist = rot[#rot] - rot[1] + if not best or dist < (best[#best] - best[1]) or + (dist == (best[#best] - best[1]) and table.concat(rot) < table.concat(best)) then + best = rot + end + end + return transpose_to_zero(best) + end + + local function invert(pcs) + local inv = {} + for i = 1, #pcs do inv[i] = (12 - pcs[i]) % 12 end + table.sort(inv) + return inv + end + + local function prime_form(pcs) + local nf = normal_form(pcs) + local inv = normal_form(invert(nf)) + local nf_str = table.concat(transpose_to_zero(nf)) + local inv_str = table.concat(transpose_to_zero(inv)) + if nf_str < inv_str then + return transpose_to_zero(nf) + else + return transpose_to_zero(inv) + end + end + + local function interval_class_vector(pcs) + local icv = {0, 0, 0, 0, 0, 0} + for i = 1, #pcs - 1 do + for j = i + 1, #pcs do + local ic = math.min((pcs[j] - pcs[i]) % 12, (pcs[i] - pcs[j]) % 12) + if ic >= 1 and ic <= 6 then + icv[ic] = icv[ic] + 1 + end + end + end + local function hex(n) + if n < 10 then return tostring(n) + elseif n == 10 then return "A" + elseif n == 11 then return "B" end + end + return "<" .. table.concat(vim.tbl_map(hex, icv)) .. ">" + end + + local out = io.open(path, "w") + for mask = 0, 4095 do + local pcs = bitmask_to_pcs(mask) + if #pcs > 0 then + local nf = normal_form(pcs) + local pf = prime_form(pcs) + local icv = interval_class_vector(pcs) + out:write(string.format("%d,%s,%s,%s\n", + mask, pcs_to_hex(nf), pcs_to_hex(pf), icv + )) + end + end + out:close() + print("Set-class CSV generated at: " .. path) +end + +return M + + +-- standalone script: + + +local bit = require("bit") + +local function bitmask_to_pcs(mask) + local pcs = {} + for i = 0, 11 do + if bit.band(mask, bit.lshift(1, i)) ~= 0 then + table.insert(pcs, i) + end + end + return pcs +end + +-- Converts a PC list to hex string like "014A9" +local function pcs_to_hex(pcs) + local symbols = {} + for _, pc in ipairs(pcs) do + if pc < 10 then + table.insert(symbols, tostring(pc)) + elseif pc == 10 then + table.insert(symbols, "A") + else + table.insert(symbols, "B") + end + end + return table.concat(symbols) +end + +-- Rotate a set +local function rotate(tbl, n) + local out = {} + local len = #tbl + for i = 1, len do + out[i] = tbl[((i + n - 2) % len) + 1] + end + return out +end + +-- Transpose set so first element is 0 +local function transpose_to_zero(set) + local transposed = {} + local root = set[1] + for i = 1, #set do + table.insert(transposed, (set[i] - root) % 12) + end + table.sort(transposed) + return transposed +end + +-- Normal form (compact rotation) +local function normal_form(pcs) + local best = nil + for i = 0, #pcs - 1 do + local rotated = rotate(pcs, i) + local dist = rotated[#rotated] - rotated[1] + if not best or dist < (best[#best] - best[1]) or + (dist == (best[#best] - best[1]) and table.concat(rotated) < table.concat(best)) then + best = rotated + end + end + return transpose_to_zero(best) +end + +-- Prime form (lowest between normal and inversion) +local function invert(pcs) + local inv = {} + for i = 1, #pcs do + inv[i] = (12 - pcs[i]) % 12 + end + table.sort(inv) + return inv +end + +local function prime_form(pcs) + local nf = normal_form(pcs) + local inv = invert(nf) + local nif = normal_form(inv) + + local nf_str = table.concat(transpose_to_zero(nf)) + local nif_str = table.concat(transpose_to_zero(nif)) + + if nf_str < nif_str then + return transpose_to_zero(nf) + else + return transpose_to_zero(nif) + end +end + +-- ICV as hex +local function interval_class_vector(pcs) + local icv = {0, 0, 0, 0, 0, 0} + for i = 1, #pcs - 1 do + for j = i + 1, #pcs do + local ic = math.min((pcs[j] - pcs[i]) % 12, (pcs[i] - pcs[j]) % 12) + if ic >= 1 and ic <= 6 then + icv[ic] = icv[ic] + 1 + end + end + end + + local function hex(n) + if n < 10 then return tostring(n) + elseif n == 10 then return "A" + elseif n == 11 then return "B" end + end + + return "<" .. table.concat(vim.tbl_map(hex, icv)) .. ">" +end + +-- Write CSV +local out = io.open("set-classes.csv", "w") +for mask = 0, 4095 do + local pcs = bitmask_to_pcs(mask) + if #pcs > 0 then + local norm = normal_form(pcs) + local prime = prime_form(pcs) + local icv = interval_class_vector(pcs) + out:write(string.format("%d,%s,%s,%s\n", + mask, + pcs_to_hex(norm), + pcs_to_hex(prime), + icv + )) + end +end +out:close() +print("CSV generated.") -- 2.39.5