-- 569 lines · 17 KiB · Lua

-- Prompt used by api.ping() to check that the server answers.
local PING_PROMPT = 'reply with exactly: pong';

-- HTTP request timeout in seconds: default value and hard upper bound.
local DEFAULT_TIMEOUT_SECONDS = 60;
local MAX_TIMEOUT_SECONDS = 60;

-- Polling of asynchronous replies, in seconds: overall wait and delay between polls.
local DEFAULT_POLL_TIMEOUT_SECONDS = 300;
local DEFAULT_POLL_INTERVAL_SECONDS = 2;

-- lua-exec defaults: correction retries after the first attempt, and the
-- execution timeout (seconds) applied to each generated chunk.
local DEFAULT_LUA_EXEC_MAX_RETRIES = 2;
local DEFAULT_LUA_EXEC_TIMEOUT_SECONDS = 5;

-- Base64 alphabet; the character for sextet value v is at position v + 1.
local B64 = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/';

--- Encode a byte string as base64 with '=' padding.
-- Used to build the HTTP Basic Authorization header.
-- @tparam string s raw bytes
-- @treturn string encoded text whose length is a multiple of 4
local function base64encode(s)
    local padCount = (3 - #s % 3) % 3;
    -- Zero-fill to a whole number of 3-byte groups; the extra output
    -- characters are replaced with '=' at the end.
    local padded = s .. string.rep('\0', padCount);

    -- Map the low 6 bits of `value` to its alphabet character.
    local function sextet(value)
        local index = value % 64 + 1;
        return B64:sub(index, index);
    end

    local chunks = {};
    local pos = 1;
    while pos <= #padded do
        local b1, b2, b3 = padded:byte(pos, pos + 2);
        local triple = (b1 * 256 + b2) * 256 + b3;
        chunks[#chunks + 1] = sextet(math.floor(triple / 262144))
            .. sextet(math.floor(triple / 4096))
            .. sextet(math.floor(triple / 64))
            .. sextet(triple);
        pos = pos + 3;
    end

    local encoded = table.concat(chunks);
    if padCount == 0 then
        return encoded;
    end
    return encoded:sub(1, #encoded - padCount) .. ('='):rep(padCount);
end

--- Remove every trailing '/' so request paths can be appended directly.
-- @tparam string s base URL
-- @treturn string the URL without trailing slashes (exactly one return value)
local function trimTrailingSlash(s)
    local trimmed = string.gsub(s, '/+$', '');
    return trimmed;
end

--- Tell whether a value is missing or only whitespace.
-- @param s any value
-- @treturn boolean true for non-strings and whitespace-only strings
local function isBlank(s)
    if type(s) ~= 'string' then
        return true;
    end
    return s:find('^%s*$') ~= nil;
end

--- Read the rest of a response handle, then close it.
-- @tparam table response handle exposing readAll() and close()
-- @return whatever readAll() produced
local function readAllAndClose(response)
    local read, close = response.readAll, response.close;
    local contents = read();
    close();
    return contents;
end

--- Return the HTTP status of a response handle, when it exposes one.
-- @tparam table response
-- @return every value of response.getResponseCode(), or nil without that method
local function statusCode(response)
    local getter = response.getResponseCode;
    if not getter then
        return nil;
    end
    return getter();
end

--- Join the text of every `{ type = 'text', text = <string> }` part, in order.
-- Entries that are not tables (malformed server JSON) are skipped instead of
-- raising an "attempt to index" error.
-- @tparam table parts sequence of message parts
-- @treturn string concatenated text ('' when there is none)
local function extractTextParts(parts)
    local texts = {};
    for _, part in ipairs(parts) do
        if type(part) == 'table' and part.type == 'text' and type(part.text) == 'string' then
            texts[#texts + 1] = part.text;
        end
    end
    return table.concat(texts, '');
end

--- Current time in seconds.
-- Uses os.epoch('utc') (milliseconds, divided by 1000) when that function
-- exists, otherwise falls back to os.clock().
-- @treturn number
local function nowSeconds()
    local epoch = os.epoch;
    if epoch then
        return epoch('utc') / 1000;
    end
    return os.clock();
end

--- Portable table.pack: capture varargs, recording the true count in `n`
-- so trailing nils are not lost.
-- @return table with positional values and field n
local function tablePack(...)
    local packed = { ... };
    packed.n = select('#', ...);
    return packed;
end

--- Tell whether a value is a string whose last character is '\n'.
-- @param s any value
-- @treturn boolean
local function endsWithNewline(s)
    if type(s) ~= 'string' then
        return false;
    end
    return s:sub(-1) == '\n';
end

--- Render values[first..last] with tostring and join them with tabs,
-- the way print() separates its arguments.
-- @tparam table values packed values (nil entries render as "nil")
-- @tparam number first first index
-- @tparam number last last index (inclusive)
-- @treturn string
local function valuesToLine(values, first, last)
    local pieces = {};
    for index = first, last do
        table.insert(pieces, tostring(values[index]));
    end
    return table.concat(pieces, '\t');
end

--- Classify a runtime error message for the lua-exec retry logic.
-- Messages mentioning "attempt to" together with "nil value", or "global"
-- together with "nil", are reported as 'identifier'; anything else is 'other'.
-- @param err error value (converted with tostring; nil/false become '')
-- @treturn string 'identifier' or 'other'
local function classifyLuaRuntimeError(err)
    local message = tostring(err or '');

    -- Plain (non-pattern) substring search.
    local function has(needle)
        return string.find(message, needle, 1, true) ~= nil;
    end

    local nilValueAccess = has('attempt to') and has('nil value');
    local nilGlobal = has('global') and has('nil');
    if nilValueAccess or nilGlobal then
        return 'identifier';
    end
    return 'other';
end

--- Replace missing or empty captured output with a readable placeholder.
-- @param output captured output
-- @return '(no output)' for nil or '', otherwise output unchanged
local function renderOutput(output)
    local isEmpty = output == nil or output == '';
    return isEmpty and '(no output)' or output;
end

--- Build the prompt asking the model for raw ComputerCraft Lua code.
-- @param userPrompt the user's request, appended as the last line
-- @treturn string newline-joined prompt
local function buildLuaExecPrompt(userPrompt)
    local lines = {
        'Write ComputerCraft Lua code to answer this user request.',
        'Reply with raw Lua code only. Do not use markdown fences or explanations.',
        'The code runs locally with normal ComputerCraft globals available.',
        'Use print() or write() for values that should be sent back. Returned values are captured too.',
        '',
        'User request:',
    };
    lines[#lines + 1] = userPrompt;
    return table.concat(lines, '\n');
end

--- Build the prompt asking the model to fix Lua that failed.
-- @param userPrompt original user request
-- @param code the Lua that failed (last line of the prompt)
-- @param err error value (rendered with tostring)
-- @param errorKind error classification (rendered with tostring)
-- @treturn string newline-joined prompt
local function buildLuaCorrectionPrompt(userPrompt, code, err, errorKind)
    local kindLine = 'Error kind: ' .. tostring(errorKind);
    local errorText = tostring(err);
    local sections = {
        -- instructions
        'The previous ComputerCraft Lua failed.',
        'Reply with corrected raw Lua code only. Do not use markdown fences or explanations.',
        '',
        -- context
        'Original user request:',
        userPrompt,
        '',
        -- diagnostics
        kindLine,
        'Error:',
        errorText,
        '',
        'Previous code:',
        code,
    };
    return table.concat(sections, '\n');
end

--- Build the prompt asking the model to answer in natural language from
-- the output of the executed Lua.
-- @param userPrompt original user request
-- @param output captured output (rendered through renderOutput)
-- @treturn string newline-joined prompt
local function buildLuaOutputPrompt(userPrompt, output)
    local rendered = renderOutput(output);
    local lines = {
        'The Lua executed successfully.',
        'Answer the original user request in natural language using the output below.',
        'Do not write more Lua unless the user explicitly asked for code.',
        '',
        'Original user request:',
        userPrompt,
        '',
        'Lua output:',
        rendered,
    };
    return table.concat(lines, '\n');
end

--- Sort key for a session: numeric time.updated, else time.created, else 0.
-- @param session decoded session object (anything else yields 0)
-- @treturn number
local function sessionTime(session)
    local times = type(session) == 'table' and session.time or nil;
    if type(times) ~= 'table' then
        return 0;
    end
    return tonumber(times.updated or times.created) or 0;
end

--- Create an opencode client for ComputerCraft.
-- Platform APIs can be injected through `opts` (useful for tests); anything
-- left unset falls back to the corresponding global.
-- @tparam[opt] table opts
--   `http`      table with get/post(request)      (default: global `http`)
--   `settings`  table with get/set/unset[/save]   (default: global `settings`)
--   `sleep`     function(seconds), between polls  (default: global `sleep`)
--   `now`       function() -> seconds             (default: nowSeconds)
--   `textutils` table with serializeJSON / unserializeJSON (default: global `textutils`)
-- @treturn table api with clearSession, listSessions, ask, createLuaExecutor,
--   luaExec and ping
local function createAi(opts)
    opts = opts or {};

    local httpLib = opts.http or http;
    local settingsLib = opts.settings or settings;
    local sleepFunc = opts.sleep or sleep;
    local nowFunc = opts.now or nowSeconds;

    -- JSON codec. Looked up on each use (as the former direct `textutils`
    -- references were) so a global installed after createAi() is still seen.
    local function json()
        return opts.textutils or textutils;
    end

    local api = {};

    -- HTTP timeout in seconds: option, then setting, then default; values
    -- above MAX_TIMEOUT_SECONDS are clamped.
    local function resolveTimeout(options)
        local raw = options.timeoutSeconds;
        if raw == nil then raw = settingsLib.get('opencc.timeout_seconds'); end
        local n = tonumber(raw);
        if not n or n <= 0 then n = DEFAULT_TIMEOUT_SECONDS; end
        if n > MAX_TIMEOUT_SECONDS then n = MAX_TIMEOUT_SECONDS; end
        return n;
    end

    -- Overall wait for an asynchronous reply: option, setting, default.
    local function resolvePollTimeout(options)
        local raw = options.pollTimeoutSeconds;
        if raw == nil then raw = settingsLib.get('opencc.poll_timeout_seconds'); end
        local n = tonumber(raw);
        if not n or n <= 0 then n = DEFAULT_POLL_TIMEOUT_SECONDS; end
        return n;
    end

    -- Delay between two polls: option, setting, default.
    local function resolvePollInterval(options)
        local raw = options.pollIntervalSeconds;
        if raw == nil then raw = settingsLib.get('opencc.poll_interval_seconds'); end
        local n = tonumber(raw);
        if not n or n <= 0 then n = DEFAULT_POLL_INTERVAL_SECONDS; end
        return n;
    end

    -- Correction retries after the first lua-exec attempt (floored, >= 0).
    local function resolveLuaExecMaxRetries(options)
        local n = tonumber(options.maxRetries);
        if n and n >= 0 then return math.floor(n); end
        return DEFAULT_LUA_EXEC_MAX_RETRIES;
    end

    -- Lua execution timeout in seconds; `false` disables it (returns nil).
    local function resolveLuaExecTimeout(options)
        if options.luaTimeoutSeconds == false then return nil; end
        local n = tonumber(options.luaTimeoutSeconds);
        if n and n > 0 then return n; end
        return DEFAULT_LUA_EXEC_TIMEOUT_SECONDS;
    end

    -- Build the request config from options and settings.
    -- Returns nil, message when no server URL is configured.
    local function resolveConfig(options)
        local url = options.serverUrl or settingsLib.get('opencc.server_url');
        if not url or url == '' then
            return nil, 'missing opencc.server_url; run: set opencc.server_url <url>';
        end
        local username = options.username or settingsLib.get('opencc.username') or 'opencode';
        local password = options.password or settingsLib.get('opencc.password') or '';
        return {
            url = trimTrailingSlash(url),
            username = username,
            password = password,
            timeoutSeconds = resolveTimeout(options),
            pollTimeoutSeconds = resolvePollTimeout(options),
            pollIntervalSeconds = resolvePollInterval(options),
        };
    end

    -- Client-generated message id: 'cc_<milliseconds>_<6 random digits>'.
    local function createMessageId()
        local t = math.floor(nowFunc() * 1000);
        return 'cc_' .. tostring(t) .. '_' .. tostring(math.random(100000, 999999));
    end

    -- True once info.finish is a string or info.time.completed is set.
    local function isMessageComplete(message)
        if type(message) ~= 'table' or type(message.info) ~= 'table' then
            return false;
        end
        if type(message.info.finish) == 'string' then
            return true;
        end
        return type(message.info.time) == 'table' and message.info.time.completed ~= nil;
    end

    -- Decode a message body; it must be an object with a `parts` table.
    local function decodeMessage(body)
        local decoded = json().unserializeJSON(body);
        if type(decoded) ~= 'table' or type(decoded.parts) ~= 'table' then
            return nil, 'reponse message invalide';
        end
        return decoded, nil;
    end

    -- Forget the stored session id (only when persisting) and return the
    -- "session not found" failure.
    local function handleMissingSession(persist)
        if persist then
            settingsLib.unset('opencc.session_id');
            if settingsLib.save then settingsLib.save(); end
        end
        return false, 'session introuvable; lance: ai new <prompt>';
    end

    -- Forward declarations: pollMessage uses doGet before it is defined.
    local doGet;
    local doPost;

    -- Poll GET /session/<id>/message/<messageId> until the message has text
    -- and is complete. Fails on transport errors, 404 (missing session),
    -- other non-200 statuses, undecodable bodies, or once pollTimeoutSeconds
    -- has elapsed.
    local function pollMessage(cfg, sessionId, messageId, persist)
        local deadline = nowFunc() + cfg.pollTimeoutSeconds;
        while true do
            local body, code = doGet(cfg, '/session/' .. sessionId .. '/message/' .. messageId);
            if not body then return false, code; end
            if code == 404 then return handleMissingSession(persist); end
            if code and code ~= 200 then
                return false, 'erreur message: HTTP ' .. tostring(code);
            end

            local decoded, decodeErr = decodeMessage(body);
            if not decoded then return false, decodeErr; end
            local reply = extractTextParts(decoded.parts);
            if reply ~= '' and isMessageComplete(decoded) then
                return true, { reply = reply, sessionId = sessionId, messageId = messageId };
            end

            if nowFunc() >= deadline then
                return false, 'delai depasse en attendant la reponse AI';
            end
            sleepFunc(cfg.pollIntervalSeconds);
        end
    end

    -- JSON headers, plus HTTP Basic auth when a non-empty password is set.
    local function buildHeaders(cfg)
        local headers = {
            ['Content-Type'] = 'application/json',
            ['Accept'] = 'application/json',
        };
        if cfg.password and cfg.password ~= '' then
            headers['Authorization'] = 'Basic ' .. base64encode(cfg.username .. ':' .. cfg.password);
        end
        return headers;
    end

    -- Perform one request through httpLib[method].
    -- Returns body, status whenever a response handle is obtained (the body
    -- is always a string then), or nil, message when no response exists.
    local function callHttp(method, request)
        -- Checked outside pcall: indexing a missing http API (nil) would
        -- otherwise raise instead of reporting an error.
        local send = httpLib and httpLib[method];
        if not send then
            return nil, 'http API unavailable: http.' .. method .. ' not found';
        end
        local ok, response, httpErr, errorResponse = pcall(send, request);
        if not ok then
            return nil, 'http ' .. method .. ' threw: ' .. tostring(response);
        end
        -- Fall back to the failure response handle http may return as its
        -- third value, so error statuses still yield a status code.
        response = response or errorResponse;
        if not response then
            return nil, 'serveur injoignable: ' .. tostring(httpErr or 'unknown error');
        end
        local code = statusCode(response);
        -- A nil body is reserved for "no response". NOTE(review): assumes
        -- readAll() may return nil for an empty body (e.g. 204) on some
        -- ComputerCraft versions — normalize it to ''.
        local body = readAllAndClose(response) or '';
        return body, code;
    end

    function doGet(cfg, path)
        return callHttp('get', {
            url = cfg.url .. path,
            headers = buildHeaders(cfg),
            timeout = cfg.timeoutSeconds,
        });
    end

    function doPost(cfg, path, payload)
        return callHttp('post', {
            url = cfg.url .. path,
            body = json().serializeJSON(payload),
            headers = buildHeaders(cfg),
            timeout = cfg.timeoutSeconds,
        });
    end

    --- Forget the persisted session id (and save settings when supported).
    function api.clearSession()
        settingsLib.unset('opencc.session_id');
        if settingsLib.save then settingsLib.save(); end
    end

    --- List sessions (GET /session), most recently updated/created first.
    -- @treturn boolean ok
    -- @return the sorted decoded list on success, an error message otherwise
    function api.listSessions(options)
        options = options or {};
        local cfg, err = resolveConfig(options);
        if not cfg then return false, err; end

        local body, code = doGet(cfg, '/session');
        if not body then return false, code; end
        if code and code ~= 200 then
            return false, 'erreur serveur: HTTP ' .. tostring(code);
        end

        local decoded = json().unserializeJSON(body);
        if type(decoded) ~= 'table' then
            return false, 'reponse invalide';
        end
        table.sort(decoded, function(a, b)
            return sessionTime(a) > sessionTime(b);
        end);
        return true, decoded;
    end

    --- Send a prompt and wait for the text reply.
    -- Session: options.sessionId, else the persisted id when persisting
    -- (options.persist ~= false), else a newly created session (stored
    -- when persisting). A 200 body from prompt_async is used directly;
    -- otherwise the message is polled.
    -- @treturn boolean ok
    -- @return { reply, sessionId, messageId } on success, an error message otherwise
    function api.ask(prompt, options)
        options = options or {};
        if isBlank(prompt) then
            return false, 'missing prompt; usage: ai <prompt>';
        end

        local cfg, err = resolveConfig(options);
        if not cfg then return false, err; end

        local persist = options.persist ~= false;
        local sessionId = options.sessionId;
        if persist and sessionId == nil then
            sessionId = settingsLib.get('opencc.session_id');
        end

        if not sessionId or sessionId == '' then
            local body, code = doPost(cfg, '/session', { title = options.sessionTitle or 'cc-ai' });
            if not body then return false, code; end
            if code and code ~= 200 then
                return false, 'impossible de creer une session: HTTP ' .. tostring(code);
            end
            local decoded = json().unserializeJSON(body);
            if type(decoded) ~= 'table' or type(decoded.id) ~= 'string' then
                return false, 'reponse session invalide';
            end
            sessionId = decoded.id;
            if persist then
                settingsLib.set('opencc.session_id', sessionId);
                if settingsLib.save then settingsLib.save(); end
            end
        end

        local messageId = options.messageId or createMessageId();
        local body, code = doPost(cfg, '/session/' .. sessionId .. '/prompt_async', {
            messageID = messageId,
            parts = { { type = 'text', text = prompt } },
        });
        if not body then return false, code; end
        if code == 404 then
            return handleMissingSession(persist);
        end
        if code and code ~= 204 and code ~= 200 then
            return false, 'erreur message: HTTP ' .. tostring(code);
        end

        if code == 200 and body and body ~= '' then
            local decoded, decodeErr = decodeMessage(body);
            if not decoded then return false, decodeErr; end
            local reply = extractTextParts(decoded.parts);
            if reply == '' then return false, 'reponse vide'; end
            return true, { reply = reply, sessionId = sessionId, messageId = messageId };
        end

        return pollMessage(cfg, sessionId, messageId, persist);
    end

    --- Build a function(code) that runs Lua source with print/write captured.
    -- The returned function yields `true, output, nil` on success (returned
    -- values are appended as a tab-separated line), or `false, message, kind`
    -- where kind is 'syntax', 'identifier' or 'other' (timeouts are 'other').
    function api.createLuaExecutor(options)
        options = options or {};
        local baseEnv = options.env or _G;
        local live = options.live ~= false;
        local livePrint = options.print or print;
        local liveWrite = options.write or write;
        local timeoutSeconds = resolveLuaExecTimeout(options);

        return function(code)
            local buffer = {};

            local function append(text)
                buffer[#buffer + 1] = text;
            end

            local function capturedPrint(...)
                local values = tablePack(...);
                local line = valuesToLine(values, 1, values.n);
                append(line .. '\n');
                if live then livePrint(...); end
            end

            local function capturedWrite(text)
                text = tostring(text or '');
                append(text);
                if live then liveWrite(text); end
            end

            -- print/write are overridden; every other name resolves via baseEnv.
            local env = setmetatable({
                print = capturedPrint,
                write = capturedWrite,
            }, { __index = baseEnv });
            local chunk, loadErr = load(code, 'ai-lua-exec', 't', env);
            if not chunk then
                return false, tostring(loadErr), 'syntax';
            end

            local result;
            local finished = false;
            local function runner()
                result = tablePack(pcall(chunk));
                finished = true;
            end

            if timeoutSeconds then
                -- Race the chunk against a sleeper; if the sleeper ends first,
                -- `finished` is still false and a timeout is reported.
                parallel.waitForAny(runner, function() sleep(timeoutSeconds); end);
            else
                runner();
            end

            if not finished then
                return false, 'lua execution timed out after ' .. tostring(timeoutSeconds) .. 's', 'other';
            end
            if not result[1] then
                return false, tostring(result[2]), classifyLuaRuntimeError(result[2]);
            end
            if result.n > 1 then
                -- Start returned values on their own line.
                if #buffer > 0 and not endsWithNewline(buffer[#buffer]) then
                    append('\n');
                end
                append(valuesToLine(result, 2, result.n) .. '\n');
            end
            return true, table.concat(buffer), nil;
        end;
    end

    --- Ask the model for Lua, run it, retry with corrections on 'syntax' or
    -- 'identifier' errors (up to maxRetries more times), then ask for a
    -- natural-language answer based on the output. Uses a dedicated,
    -- non-persisted session.
    -- @treturn boolean ok
    -- @treturn table { reply, output, code, attempts, sessionId } on success,
    --   or { error, attempts, errorKind, ... } on failure
    function api.luaExec(userPrompt, options)
        options = options or {};
        if isBlank(userPrompt) then
            return false, { error = 'missing prompt; usage: ai lua-exec <prompt>', attempts = 0 };
        end

        local log = options.log or function() end;
        local executor = options.executor or api.createLuaExecutor(options);
        local maxRetries = resolveLuaExecMaxRetries(options);
        local maxAttempts = maxRetries + 1;
        local sessionId;

        -- Options for every nested ask(): never persist, reuse the lua-exec
        -- session once created, and forward connection/timing overrides
        -- (including the polling ones, previously dropped).
        local function askOptions()
            return {
                persist = false,
                sessionId = sessionId,
                sessionTitle = 'cc-ai lua-exec',
                serverUrl = options.serverUrl,
                username = options.username,
                password = options.password,
                timeoutSeconds = options.timeoutSeconds,
                pollTimeoutSeconds = options.pollTimeoutSeconds,
                pollIntervalSeconds = options.pollIntervalSeconds,
            };
        end

        log('requesting Lua from AI');
        local ok, result = api.ask(buildLuaExecPrompt(userPrompt), askOptions());
        if not ok then
            return false, { error = result, attempts = 0, errorKind = 'ai' };
        end
        sessionId = result.sessionId;
        log('session: ' .. sessionId);

        local code = result.reply;
        for attempt = 1, maxAttempts do
            log('attempt ' .. tostring(attempt) .. '/' .. tostring(maxAttempts));
            log('code:\n' .. code);

            local execOk, outputOrErr, errorKind = executor(code);
            if execOk then
                local output = outputOrErr or '';
                log('output:\n' .. renderOutput(output));
                log('requesting final reply');
                local finalOk, finalResult = api.ask(buildLuaOutputPrompt(userPrompt, output), askOptions());
                if not finalOk then
                    return false, {
                        error = finalResult,
                        attempts = attempt,
                        errorKind = 'ai',
                        code = code,
                        output = output,
                        sessionId = sessionId,
                    };
                end
                log('final reply received');
                return true, {
                    reply = finalResult.reply,
                    output = output,
                    code = code,
                    attempts = attempt,
                    sessionId = sessionId,
                };
            end

            errorKind = errorKind or 'other';
            log('error (' .. tostring(errorKind) .. '):\n' .. tostring(outputOrErr));
            -- Only syntax/identifier errors are worth a correction round.
            if (errorKind ~= 'syntax' and errorKind ~= 'identifier') or attempt >= maxAttempts then
                return false, {
                    error = outputOrErr,
                    attempts = attempt,
                    errorKind = errorKind,
                    code = code,
                    sessionId = sessionId,
                    retryExhausted = attempt >= maxAttempts,
                };
            end

            log('requesting corrected Lua');
            local correctionOk, correctionResult = api.ask(
                buildLuaCorrectionPrompt(userPrompt, code, outputOrErr, errorKind),
                askOptions()
            );
            if not correctionOk then
                return false, {
                    error = correctionResult,
                    attempts = attempt,
                    errorKind = 'ai',
                    code = code,
                    sessionId = sessionId,
                };
            end
            code = correctionResult.reply;
        end

        -- Defensive: every loop iteration returns.
        return false, { error = 'lua-exec failed unexpectedly', attempts = maxAttempts };
    end

    --- Connectivity check: ask PING_PROMPT with the given options.
    function api.ping(options)
        return api.ask(PING_PROMPT, options);
    end

    return api;
end

-- Module export: the client factory.
return createAi