From: Tucker Johnson Date: Mon, 23 Jun 2025 15:39:24 +0000 (-0400) Subject: generator-included X-Git-Url: https://git.newer.systems/?a=commitdiff_plain;h=6cc1ad359acbc3e155520d2888f0a4369909e09d;p=setclass.git generator-included --- diff --git a/core.lua b/core.lua index 8d46c90..2448465 100644 --- a/core.lua +++ b/core.lua @@ -1,4 +1,5 @@ bit = require("bit") +local generator = require("setclass.generator") local M = {} @@ -9,19 +10,25 @@ local lookup = {} local function load_csv() local f = io.open(csvPath, "r") - if not f then error("Could not open setclasses.csv at " .. csvPath) end + if not f then + return false + end for line in f:lines() do - local dec, sc, icv = line:match("^(%d+),%[?([0-9AB ]+)%]?%,%<([0-9 ]+)%>$") - if dec and sc then + local dec, set, sc, icv = line:match("^(%d+),([0-9AB]+),([0-9AB]+),(<[0-9AB]+>)") + if dec then lookup[tonumber(dec)] = { - sc = "[" .. sc .. "]", - icv = "<" .. icv .. ">" + set = set, + sc = sc, + icv = icv } end end f:close() + csvLoaded = true + return true end + function M.get_visual_selection() local start_pos = vim.fn.getpos("'<") local end_pos = vim.fn.getpos("'>") @@ -36,9 +43,14 @@ function M.get_visual_selection() end function M.analyze_set(inputStr) + if not csvLoaded then - load_csv() - csvLoaded = true + local success = load_csv() + if not success then + print("[SetClass] CSV not found — generating now...") + generator.generate_csv(csvPath) + assert(load_csv(), "Failed to load generated set-classes.csv") + end end local function parse_set(input) @@ -73,7 +85,7 @@ function M.analyze_set(inputStr) local mask = parse_set(inputStr) return { - input = inputStr, + input = lookup[mask].set, decimal = mask, set_class = lookup[mask].sc, interval_class_vector = lookup[mask].icv, diff --git a/generate_setClasses.lua b/generate_setClasses.lua deleted file mode 100644 index ce0a19e..0000000 --- a/generate_setClasses.lua +++ /dev/null @@ -1,112 +0,0 @@ -local bit32 = require("bit32") or require("bit") - --- Convert pitch class to A/B notation -local function pc_to_char(pc) - if pc < 10 then return tostring(pc) - elseif pc == 10 then return 'A' - elseif pc == 11 then return 'B' end -end - --- Generate a string for a PC set -local function pcs_to_string(pcs) - local out = {} - for _, pc in ipairs(pcs) do table.insert(out, pc_to_char(pc)) end - return "[" .. table.concat(out, "") .. "]" -end - --- Get all rotations of a sorted set -local function rotations(pcs) - local out = {} - local n = #pcs - for i = 1, n do - local rot = {} - local offset = pcs[i] - for j = 0, n - 1 do - local val = (pcs[(i + j - 1) % n + 1] - offset + 12) % 12 - table.insert(rot, val) - end - table.insert(out, rot) - end - return out -end - --- Get normal form of a set (most packed rotation) -local function normal_form(pcs) - table.sort(pcs) - local all_rots = rotations(pcs) - - -- Find the most compact (smallest span) - table.sort(all_rots, function(a, b) - local span_a = (a[#a] - a[1]) % 12 - local span_b = (b[#b] - b[1]) % 12 - if span_a ~= span_b then return span_a < span_b end - - -- If equal span, choose lex smallest - for i = 1, #a do - if a[i] ~= b[i] then return a[i] < b[i] end - end - return false - end) - - return all_rots[1] -end - --- Invert a set mod 12 -local function invert(pcs) - local inv = {} - for _, pc in ipairs(pcs) do - table.insert(inv, (12 - pc) % 12) - end - return inv -end - --- Get prime form of a set -local function prime_form(pcs) - local nf1 = normal_form(pcs) - local nf2 = normal_form(invert(pcs)) - - -- Lexical comparison - for i = 1, #nf1 do - if nf1[i] < nf2[i] then return nf1 - elseif nf1[i] > nf2[i] then return nf2 end - end - return nf1 -end - --- Convert bitmask to PC set -local function bitmask_to_pcs(n) - local pcs = {} - for i = 0, 11 do - if bit32.band(n, bit32.lshift(1, i)) ~= 0 then - table.insert(pcs, i) - end - end - return pcs -end - -local function interval_class_vector(pcs) - local icv = {0, 0, 0, 0, 0, 0} - table.sort(pcs) - for i = 1, #pcs do - for j = i + 1, #pcs do - local interval = (pcs[j] - pcs[i]) % 12 - local ic = math.min(interval, 12 - interval) - if ic > 0 and ic <= 6 then - icv[ic] = icv[ic] +1 - end - end - end - return table.concat(icv) -end - --- Write the CSV -local output = io.open("set-classes.csv", "w") -for i = 1, 4095 do -- skip 0 (empty set) - local pcs = bitmask_to_pcs(i) - local pf = prime_form(pcs) - local icv = interval_class_vector(pcs) - output:write(i .. "," .. pcs_to_string(pf) .. ",<" .. icv .. ">\n") -end -output:close() - -print("✅ setclasses.csv generated with correct prime forms.") diff --git a/generator.lua b/generator.lua new file mode 100644 index 0000000..7eda678 --- /dev/null +++ b/generator.lua @@ -0,0 +1,91 @@ +local M = {} + +local bit = require("bit") + +-- Convert bitmask to PC list (integers) +local function mask_to_pcs(mask) + local pcs = {} + for i = 0, 11 do + if bit.band(mask, bit.lshift(1, i)) ~= 0 then + table.insert(pcs, i) + end + end + return pcs +end + +-- Convert PC list to hex string (e.g., 10 -> A, 11 -> B) +local function pcs_to_hex(pcs) + local out = {} + for _, pc in ipairs(pcs) do + table.insert(out, string.format("%X", pc)) + end + return table.concat(out) +end + +-- Transpose bitmask left by n +local function transpose(mask, n) + return bit.band(bit.lshift(mask, n) + bit.rshift(mask, 12 - n), 0xFFF) +end + +-- Invert a bitmask (I-transform) +local function invert(mask) + local result = 0 + for i = 0, 11 do + if bit.band(mask, bit.lshift(1, i)) ~= 0 then + local inv = (12 - i) % 12 + result = bit.bor(result, bit.lshift(1, inv)) + end + end + return result +end + +-- Find prime form of bitmask +local function prime_form_bitmask(mask) + local min = 0xFFF + 1 + for i = 0, 11 do + local t = transpose(mask, i) + local ti = transpose(invert(mask), i) + if t < min then min = t end + if ti < min then min = ti end + end + return min +end + +-- Generate ICV from bitmask +local function icv_from_mask(mask) + local pcs = mask_to_pcs(mask) + local icv = {0, 0, 0, 0, 0, 0} + for i = 1, #pcs - 1 do + for j = i + 1, #pcs do + local a, b = pcs[i], pcs[j] + local ic = math.min((a - b) % 12, (b - a) % 12) + if ic >= 1 and ic <= 6 then + icv[ic] = icv[ic] + 1 + end + end + end + local out = {} + for _, v in ipairs(icv) do + table.insert(out, string.format("%X", v)) + end + return "<" .. table.concat(out) .. ">" +end + +-- Generate CSV with format: decimal,pc_order,prime_form,icv +function M.generate_csv(path) + local out = io.open(path, "w") + local total = 4096 + for mask = 0, total - 1 do + local pcs = mask_to_pcs(mask) + if #pcs > 0 then + local pc_order = pcs_to_hex(pcs) + local prime_mask = prime_form_bitmask(mask) + local prime_form = pcs_to_hex(mask_to_pcs(prime_mask)) + local icv = icv_from_mask(mask) + out:write(string.format("%d,%s,%s,%s\n", mask, pc_order, prime_form, icv)) + end + end + out:close() +end + +return M diff --git a/init.lua b/init.lua index df795fe..5365bd0 100644 --- a/init.lua +++ b/init.lua @@ -5,10 +5,10 @@ function M.analyze_selection() local input = core.get_visual_selection() local result = core.analyze_set(input) if result then - vim.notify("input: {" .. result.input .. "} (" .. result.decimal .. ")") - vim.notify("set class: " .. result.set_class .. " " .. result.interval_class_vector) - vim.notify("abs. complement: " .. result.complement_class) - vim.notify("M5: " .. result.m5_class) + vim.notify(" input: " .. result.input .. " (" .. result.decimal .. ")") + vim.notify(" set class: " .. result.set_class .. " " .. result.interval_class_vector) + vim.notify("abs. complement: " .. result.complement_class) + vim.notify(" m5 class: " .. result.m5_class) else vim.notify("unable to parse collection") end diff --git a/snippets.lua b/snippets.lua deleted file mode 100644 index 3830767..0000000 --- a/snippets.lua +++ /dev/null @@ -1,320 +0,0 @@ --- load csv -local bit = require("bit") -local generator = require("setclass.generator") - -local M = {} - -local pluginDir = debug.getinfo(1, "S").source:sub(2):match("(.*/)") -local csvPath = pluginDir .. "set-classes.csv" - -local csvLoaded = false -local lookup = {} - -local function load_csv() - local f = io.open(csvPath, "r") - if not f then - return false - end - for line in f:lines() do - local dec, norm, prime, icv = line:match("^(%d+),([0-9AB]+),([0-9AB]+),(<[0-9AB]+>)") - if dec then - lookup[tonumber(dec)] = { - normal = norm, - prime = prime, - icv = icv - } - end - end - f:close() - csvLoaded = true - return true -end - --- lazy load in analyze -function M.analyze_set(inputStr) - -- Try to load the CSV, generate if needed - if not csvLoaded then - local success = load_csv() - if not success then - print("[SetClass] CSV not found — generating now...") - generator.generate_csv(csvPath) - assert(load_csv(), "Failed to load generated set-classes.csv") - end - end - - -- Now safely proceed with your parse/lookup logic... -end - --- csv generator (generator.lua) - - --- Progress bar helper -local function show_progress(current, total, width) - width = width or 30 - local pct = current / total - local filled = math.floor(pct * width) - local bar = string.rep("#", filled) .. string.rep("-", width - filled) - io.write(string.format("\r[%s] %3d%%", bar, pct * 100)) - io.flush() -end - --- Main CSV generation -function M.generate_csv(path) - local out = io.open(path, "w") - local total = 4095 - for mask = 1, total do - local pcs = bitmask_to_pcs(mask) - if #pcs > 0 then - -- compute norm, prime, icv... - out:write(string.format("%d,%s,%s,%s\n", mask, norm_hex, prime_hex, icv)) - end - show_progress(mask, total) - end - out:close() - io.write("\nCSV generated: " .. path .. "\n") -end - - --- setclass/generator.lua -local M = {} - -function M.generate_csv(path) - local bit = require("bit") - - local function bitmask_to_pcs(mask) - local pcs = {} - for i = 0, 11 do - if bit.band(mask, bit.lshift(1, i)) ~= 0 then - table.insert(pcs, i) - end - end - return pcs - end - - local function pcs_to_hex(pcs) - local out = {} - for _, pc in ipairs(pcs) do - if pc < 10 then - table.insert(out, tostring(pc)) - elseif pc == 10 then - table.insert(out, "A") - else - table.insert(out, "B") - end - end - return table.concat(out) - end - - local function rotate(t, n) - local len, out = #t, {} - for i = 1, len do - out[i] = t[((i + n - 2) % len) + 1] - end - return out - end - - local function transpose_to_zero(set) - local out, root = {}, set[1] - for i = 1, #set do out[i] = (set[i] - root) % 12 end - table.sort(out) - return out - end - - local function normal_form(pcs) - local best - for i = 0, #pcs - 1 do - local rot = rotate(pcs, i) - local dist = rot[#rot] - rot[1] - if not best or dist < (best[#best] - best[1]) or - (dist == (best[#best] - best[1]) and table.concat(rot) < table.concat(best)) then - best = rot - end - end - return transpose_to_zero(best) - end - - local function invert(pcs) - local inv = {} - for i = 1, #pcs do inv[i] = (12 - pcs[i]) % 12 end - table.sort(inv) - return inv - end - - local function prime_form(pcs) - local nf = normal_form(pcs) - local inv = normal_form(invert(nf)) - local nf_str = table.concat(transpose_to_zero(nf)) - local inv_str = table.concat(transpose_to_zero(inv)) - if nf_str < inv_str then - return transpose_to_zero(nf) - else - return transpose_to_zero(inv) - end - end - - local function interval_class_vector(pcs) - local icv = {0, 0, 0, 0, 0, 0} - for i = 1, #pcs - 1 do - for j = i + 1, #pcs do - local ic = math.min((pcs[j] - pcs[i]) % 12, (pcs[i] - pcs[j]) % 12) - if ic >= 1 and ic <= 6 then - icv[ic] = icv[ic] + 1 - end - end - end - local function hex(n) - if n < 10 then return tostring(n) - elseif n == 10 then return "A" - elseif n == 11 then return "B" end - end - return "<" .. table.concat(vim.tbl_map(hex, icv)) .. ">" - end - - local out = io.open(path, "w") - for mask = 0, 4095 do - local pcs = bitmask_to_pcs(mask) - if #pcs > 0 then - local nf = normal_form(pcs) - local pf = prime_form(pcs) - local icv = interval_class_vector(pcs) - out:write(string.format("%d,%s,%s,%s\n", - mask, pcs_to_hex(nf), pcs_to_hex(pf), icv - )) - end - end - out:close() - print("Set-class CSV generated at: " .. path) -end - -return M - - --- standalone script: - - -local bit = require("bit") - -local function bitmask_to_pcs(mask) - local pcs = {} - for i = 0, 11 do - if bit.band(mask, bit.lshift(1, i)) ~= 0 then - table.insert(pcs, i) - end - end - return pcs -end - --- Converts a PC list to hex string like "014A9" -local function pcs_to_hex(pcs) - local symbols = {} - for _, pc in ipairs(pcs) do - if pc < 10 then - table.insert(symbols, tostring(pc)) - elseif pc == 10 then - table.insert(symbols, "A") - else - table.insert(symbols, "B") - end - end - return table.concat(symbols) -end - --- Rotate a set -local function rotate(tbl, n) - local out = {} - local len = #tbl - for i = 1, len do - out[i] = tbl[((i + n - 2) % len) + 1] - end - return out -end - --- Transpose set so first element is 0 -local function transpose_to_zero(set) - local transposed = {} - local root = set[1] - for i = 1, #set do - table.insert(transposed, (set[i] - root) % 12) - end - table.sort(transposed) - return transposed -end - --- Normal form (compact rotation) -local function normal_form(pcs) - local best = nil - for i = 0, #pcs - 1 do - local rotated = rotate(pcs, i) - local dist = rotated[#rotated] - rotated[1] - if not best or dist < (best[#best] - best[1]) or - (dist == (best[#best] - best[1]) and table.concat(rotated) < table.concat(best)) then - best = rotated - end - end - return transpose_to_zero(best) -end - --- Prime form (lowest between normal and inversion) -local function invert(pcs) - local inv = {} - for i = 1, #pcs do - inv[i] = (12 - pcs[i]) % 12 - end - table.sort(inv) - return inv -end - -local function prime_form(pcs) - local nf = normal_form(pcs) - local inv = invert(nf) - local nif = normal_form(inv) - - local nf_str = table.concat(transpose_to_zero(nf)) - local nif_str = table.concat(transpose_to_zero(nif)) - - if nf_str < nif_str then - return transpose_to_zero(nf) - else - return transpose_to_zero(nif) - end -end - --- ICV as hex -local function interval_class_vector(pcs) - local icv = {0, 0, 0, 0, 0, 0} - for i = 1, #pcs - 1 do - for j = i + 1, #pcs do - local ic = math.min((pcs[j] - pcs[i]) % 12, (pcs[i] - pcs[j]) % 12) - if ic >= 1 and ic <= 6 then - icv[ic] = icv[ic] + 1 - end - end - end - - local function hex(n) - if n < 10 then return tostring(n) - elseif n == 10 then return "A" - elseif n == 11 then return "B" end - end - - return "<" .. table.concat(vim.tbl_map(hex, icv)) .. ">" -end - --- Write CSV -local out = io.open("set-classes.csv", "w") -for mask = 0, 4095 do - local pcs = bitmask_to_pcs(mask) - if #pcs > 0 then - local norm = normal_form(pcs) - local prime = prime_form(pcs) - local icv = interval_class_vector(pcs) - out:write(string.format("%d,%s,%s,%s\n", - mask, - pcs_to_hex(norm), - pcs_to_hex(prime), - icv - )) - end -end -out:close() -print("CSV generated.")