cc-libs/apis/libai.lua

766 lines
25 KiB
Lua

-- Prompt sent by api.ping.
local PING_PROMPT = 'reply with exactly: pong';
-- Fallback and upper bound applied by resolveTimeout.
local DEFAULT_TIMEOUT_SECONDS = 60;
local MAX_TIMEOUT_SECONDS = 60;
-- Fallback and upper bound applied by resolvePollTimeout.
local DEFAULT_POLL_TIMEOUT_SECONDS = 600;
local MAX_POLL_TIMEOUT_SECONDS = 600;
-- Fallback delay between two polls, applied by resolvePollInterval.
local DEFAULT_POLL_INTERVAL_SECONDS = 2;
-- Fallbacks used by resolveLuaExecMaxRetries and resolveLuaExecTimeout.
local DEFAULT_LUA_EXEC_MAX_RETRIES = 2;
local DEFAULT_LUA_EXEC_TIMEOUT_SECONDS = 5;
-- Settings keys used when the caller does not supply its own:
-- the persisted session id and the agent name.
local DEFAULT_SESSION_SETTING_KEY = 'opencc.session_id';
local DEFAULT_AGENT_SETTING_KEY = 'opencc.agent';
local createHttp = require('/apis/libhttp');
--- Tell whether a value is unusable as text.
-- @param s any value
-- @treturn boolean true when `s` is not a string, or is a string that
--   contains no non-whitespace character (the empty string included)
local function isBlank(s)
  if type(s) ~= 'string' then
    return true;
  end
  return string.find(s, '%S') == nil;
end
--- Concatenate the text of every text part in a message part list.
-- Entries are kept only when they are tables with `type == 'text'` and a
-- string `text` field; everything else is skipped.
-- @param parts sequence of part tables (any other value yields '')
-- @treturn string the kept texts joined without a separator
local function extractTextParts(parts)
  local pieces = {};
  if type(parts) == 'table' then
    for _, entry in ipairs(parts) do
      local isTextPart = type(entry) == 'table'
        and entry.type == 'text'
        and type(entry.text) == 'string';
      if isTextPart then
        table.insert(pieces, entry.text);
      end
    end
  end
  return table.concat(pieces);
end
--- Default clock used when no `now` function is injected.
-- Returns `os.epoch('utc') / 1000` when `os.epoch` exists, otherwise
-- `os.clock()`.
-- NOTE(review): presumably seconds, assuming os.epoch reports milliseconds
-- as in CC: Tweaked; confirm for other hosts.
-- @treturn number
local function nowSeconds()
  if not os.epoch then
    return os.clock();
  end
  return os.epoch('utc') / 1000;
end
--- Capture varargs in a table, recording the true argument count in `n`
-- so that nil arguments (including trailing ones) are not lost.
-- @param ... any values
-- @treturn table array of the arguments plus field `n`
local function tablePack(...)
  local packed = { ... };
  packed.n = select('#', ...);
  return packed;
end
--- Tell whether `s` is a string whose last character is a line feed.
-- @param s any value
-- @treturn boolean false for non-strings and for the empty string
local function endsWithNewline(s)
  if type(s) ~= 'string' then
    return false;
  end
  local lastChar = string.sub(s, -1);
  return lastChar == '\n';
end
--- Render `values[first..last]` as one tab-separated line.
-- Each slot goes through `tostring`, so nil slots appear as "nil".
-- @tparam table values source table
-- @tparam number first first index to render
-- @tparam number last last index to render (an empty range yields '')
-- @treturn string
local function valuesToLine(values, first, last)
  local rendered = {};
  local count = 0;
  for index = first, last do
    count = count + 1;
    rendered[count] = tostring(values[index]);
  end
  return table.concat(rendered, '\t', 1, count);
end
--- Classify a Lua runtime error message.
-- The message counts as 'identifier' when it contains both "attempt to"
-- and "nil value", or both "global" and "nil" (plain substring search);
-- anything else is 'other'.
-- @param err error value; nil/false is treated as an empty message
-- @treturn string 'identifier' or 'other'
local function classifyLuaRuntimeError(err)
  local message = tostring(err or '');
  local function mentions(needle)
    return string.find(message, needle, 1, true) ~= nil;
  end
  local nilOperation = mentions('attempt to') and mentions('nil value');
  local nilGlobal = mentions('global') and mentions('nil');
  if nilOperation or nilGlobal then
    return 'identifier';
  end
  return 'other';
end
--- Substitute a placeholder for missing output.
-- @param output captured output
-- @return '(no output)' when `output` is nil or the empty string,
--   otherwise `output` unchanged
local function renderOutput(output)
  local isEmpty = output == nil or output == '';
  return isEmpty and '(no output)' or output;
end
--- Build the prompt that asks the model for raw ComputerCraft Lua code
-- answering `userPrompt`.
-- @param userPrompt the user's request, appended as the last line
-- @treturn string newline-joined prompt
local function buildLuaExecPrompt(userPrompt)
  local sections = {
    'Write ComputerCraft Lua code to answer this user request.',
    'Reply with raw Lua code only. Do not use markdown fences or explanations.',
    'The code runs locally with normal ComputerCraft globals available.',
    'Use print() or write() for values that should be sent back. Returned values are captured too.',
    '',
    'User request:',
    userPrompt,
  };
  local prompt = table.concat(sections, '\n');
  return prompt;
end
--- Build the follow-up prompt sent after generated Lua failed, asking for
-- corrected code.
-- @param userPrompt the original user request
-- @param code the Lua source that failed
-- @param err the failure; rendered with `tostring`
-- @param errorKind the failure category; rendered with `tostring`
-- @treturn string newline-joined prompt
local function buildLuaCorrectionPrompt(userPrompt, code, err, errorKind)
  local kindLine = 'Error kind: ' .. tostring(errorKind);
  local errorText = tostring(err);
  local sections = {
    'The previous ComputerCraft Lua failed.',
    'Reply with corrected raw Lua code only. Do not use markdown fences or explanations.',
    '',
    'Original user request:',
    userPrompt,
    '',
    kindLine,
    'Error:',
    errorText,
    '',
    'Previous code:',
    code,
  };
  return table.concat(sections, '\n');
end
--- Build the prompt that asks the model to answer the original request in
-- natural language from the output of the executed Lua.
-- @param userPrompt the original user request
-- @param output captured output; passed through renderOutput, so nil or
--   '' appears as '(no output)'
-- @treturn string newline-joined prompt
local function buildLuaOutputPrompt(userPrompt, output)
  local shownOutput = renderOutput(output);
  local sections = {
    'The Lua executed successfully.',
    'Answer the original user request in natural language using the output below.',
    'Do not write more Lua unless the user explicitly asked for code.',
    '',
    'Original user request:',
    userPrompt,
    '',
    'Lua output:',
    shownOutput,
  };
  return table.concat(sections, '\n');
end
--- Call a zero-argument getter on an os-like table, tolerating failure.
-- @param osLib table expected to hold the getter
-- @tparam string name key of the getter function
-- @return the getter's first result, or nil when `osLib` is not a table,
--   the entry is not a function, or the call raises
local function readOsValue(osLib, name)
  if type(osLib) ~= 'table' then
    return nil;
  end
  local getter = osLib[name];
  if type(getter) ~= 'function' then
    return nil;
  end
  local succeeded, result = pcall(getter);
  if succeeded then
    return result;
  end
  return nil;
end
--- Prefix a prompt with a `<caller-context>` block describing the caller.
-- The computer id line is added when `getComputerID` yields a non-nil
-- value; the label line is added when `getComputerLabel` yields a
-- non-blank string.
-- @param prompt the user prompt, appended after a 'User prompt:' line
-- @param osLib os-like table queried through readOsValue
-- @treturn string newline-joined prompt
local function buildPromptWithCallerContext(prompt, osLib)
  local id = readOsValue(osLib, 'getComputerID');
  local label = readOsValue(osLib, 'getComputerLabel');
  local out = {
    '<caller-context hidden="true">',
    'Use this context silently to identify the in-game ComputerCraft caller.',
  };
  local function push(line)
    out[#out + 1] = line;
  end
  if id ~= nil then
    push('computer id: ' .. tostring(id));
  end
  if not isBlank(label) then
    push('computer label: ' .. tostring(label));
  end
  push('</caller-context>');
  push('');
  push('User prompt:');
  push(prompt);
  return table.concat(out, '\n');
end
--- Sort key for a session: `time.updated` if truthy, else `time.created`,
-- converted with `tonumber`.
-- @param session session table
-- @treturn number 0 when the session, its `time` table, or a numeric
--   value is missing
local function sessionTime(session)
  local stamps = type(session) == 'table' and session.time;
  if type(stamps) ~= 'table' then
    return 0;
  end
  local latest = stamps.updated;
  if not latest then
    latest = stamps.created;
  end
  return tonumber(latest) or 0;
end
--- Create the AI client API table, bound to the given dependencies.
-- Every field of `opts` is optional; each one replaces a default:
--   opts.http        HTTP library (default: the global `http`)
--   opts.settings    settings library (default: the global `settings`)
--   opts.eventloop   event-loop factory (default: require('/apis/eventloop'))
--   opts.now         clock function (default: nowSeconds)
--   opts.os          os-like table used for caller context (default: `os`)
--   opts.httpClient  ready-made HTTP client (default: createHttp({ http = httpLib }))
-- @tparam[opt] table opts dependency overrides
-- @treturn table the `api` table that the functions below are attached to
local function createAi(opts)
opts = opts or {};
local httpLib = opts.http or http;
local settingsLib = opts.settings or settings;
local eventloopFactory = opts.eventloop or require('/apis/eventloop');
local nowFunc = opts.now or nowSeconds;
local osLib = opts.os or os;
local httpClient = opts.httpClient or createHttp({ http = httpLib });
-- Public functions are attached to this table and it is returned at the end.
local api = {};
--- Resolve the request timeout.
-- Reads `options.timeoutSeconds`, falling back to the
-- 'opencc.timeout_seconds' setting. Missing, non-numeric or non-positive
-- values become DEFAULT_TIMEOUT_SECONDS; the result is capped at
-- MAX_TIMEOUT_SECONDS.
-- @tparam table options
-- @treturn number
local function resolveTimeout(options)
  local configured = options.timeoutSeconds;
  if configured == nil then
    configured = settingsLib.get('opencc.timeout_seconds');
  end
  local seconds = tonumber(configured);
  if seconds == nil or seconds <= 0 then
    seconds = DEFAULT_TIMEOUT_SECONDS;
  end
  if seconds > MAX_TIMEOUT_SECONDS then
    return MAX_TIMEOUT_SECONDS;
  end
  return seconds;
end
--- Resolve how long polling for a reply may last.
-- Reads `options.pollTimeoutSeconds`, falling back to the
-- 'opencc.poll_timeout_seconds' setting. Missing, non-numeric or
-- non-positive values become DEFAULT_POLL_TIMEOUT_SECONDS; the result is
-- capped at MAX_POLL_TIMEOUT_SECONDS.
-- @tparam table options
-- @treturn number
local function resolvePollTimeout(options)
  local configured = options.pollTimeoutSeconds;
  if configured == nil then
    configured = settingsLib.get('opencc.poll_timeout_seconds');
  end
  local seconds = tonumber(configured);
  if seconds == nil or seconds <= 0 then
    seconds = DEFAULT_POLL_TIMEOUT_SECONDS;
  end
  if seconds > MAX_POLL_TIMEOUT_SECONDS then
    return MAX_POLL_TIMEOUT_SECONDS;
  end
  return seconds;
end
--- Resolve the delay between two polls.
-- Reads `options.pollIntervalSeconds`, falling back to the
-- 'opencc.poll_interval_seconds' setting. Missing, non-numeric or
-- non-positive values become DEFAULT_POLL_INTERVAL_SECONDS. No upper
-- bound is applied.
-- @tparam table options
-- @treturn number
local function resolvePollInterval(options)
  local configured = options.pollIntervalSeconds;
  if configured == nil then
    configured = settingsLib.get('opencc.poll_interval_seconds');
  end
  local seconds = tonumber(configured);
  if seconds == nil or seconds <= 0 then
    return DEFAULT_POLL_INTERVAL_SECONDS;
  end
  return seconds;
end
--- Resolve how many corrections luaExec may request.
-- @tparam table options reads `options.maxRetries`
-- @treturn number the value floored to an integer when it is numeric and
--   >= 0, otherwise DEFAULT_LUA_EXEC_MAX_RETRIES
local function resolveLuaExecMaxRetries(options)
  local requested = tonumber(options.maxRetries);
  if requested == nil or not (requested >= 0) then
    return DEFAULT_LUA_EXEC_MAX_RETRIES;
  end
  return math.floor(requested);
end
--- Resolve the execution timeout for generated Lua.
-- @tparam table options reads `options.luaTimeoutSeconds`
-- @return nil when the option is exactly `false` (timeout disabled); the
--   numeric value when it is > 0; otherwise DEFAULT_LUA_EXEC_TIMEOUT_SECONDS
local function resolveLuaExecTimeout(options)
  local requested = options.luaTimeoutSeconds;
  if requested == false then
    return nil;
  end
  local seconds = tonumber(requested);
  if seconds ~= nil and seconds > 0 then
    return seconds;
  end
  return DEFAULT_LUA_EXEC_TIMEOUT_SECONDS;
end
--- Resolve the provider/model pair.
-- Each id comes from the option (`providerID`, `modelID`) or, when that is
-- falsy, from the 'opencc.provider_id' / 'opencc.model_id' setting.
-- @tparam table options
-- @return providerId, modelId; or nil, nil unless both are non-blank
local function resolveModel(options)
  local provider = options.providerID or settingsLib.get('opencc.provider_id');
  local model = options.modelID or settingsLib.get('opencc.model_id');
  if not isBlank(provider) and not isBlank(model) then
    return provider, model;
  end
  return nil, nil;
end
--- Resolve the agent name.
-- Uses `options.agent`, or the DEFAULT_AGENT_SETTING_KEY setting when the
-- option is nil.
-- @tparam table options
-- @return the agent string, or nil when it is blank
local function resolveAgent(options)
  local name = options.agent;
  if name == nil then
    name = settingsLib.get(DEFAULT_AGENT_SETTING_KEY);
  end
  if not isBlank(name) then
    return name;
  end
  return nil;
end
--- Assemble the connection configuration used by every request.
-- Each value comes from `options` first, then from the matching
-- 'opencc.*' setting.
-- @tparam table options may carry serverUrl, username, password,
--   directory, providerID, modelID, agent, timeoutSeconds,
--   pollTimeoutSeconds, pollIntervalSeconds
-- @return the config table, or nil plus an error message when no usable
--   server URL is configured
local function resolveConfig(options)
  local url = options.serverUrl or settingsLib.get('opencc.server_url');
  -- Use isBlank, like resolveModel/resolveAgent, so a whitespace-only or
  -- non-string URL is reported here instead of reaching trimTrailingSlash.
  if isBlank(url) then
    return nil, 'missing opencc.server_url; run: set opencc.server_url <url>';
  end
  local username = options.username or settingsLib.get('opencc.username') or 'opencode';
  local password = options.password or settingsLib.get('opencc.password') or '';
  local directory = options.directory or settingsLib.get('opencc.directory');
  local providerId, modelId = resolveModel(options);
  return {
    url = httpClient.trimTrailingSlash(url),
    username = username,
    password = password,
    directory = directory,
    providerID = providerId,
    modelID = modelId,
    agent = resolveAgent(options),
    timeoutSeconds = resolveTimeout(options),
    pollTimeoutSeconds = resolvePollTimeout(options),
    pollIntervalSeconds = resolvePollInterval(options),
  };
end
--- Build the JSON payload for the async prompt endpoint.
-- @tparam table cfg reads providerID, modelID and agent
-- @param messageId id stored as `messageID`
-- @param prompt text of the single text part
-- @treturn table payload; `model` is set only when both ids are truthy,
--   `agent` only when cfg.agent is truthy
local function buildPromptBody(cfg, messageId, prompt)
  local payload = {};
  payload.messageID = messageId;
  payload.parts = { { type = 'text', text = prompt } };
  local hasModel = cfg.providerID and cfg.modelID;
  if hasModel then
    payload.model = { providerID = cfg.providerID, modelID = cfg.modelID };
  end
  if cfg.agent then
    payload.agent = cfg.agent;
  end
  return payload;
end
--- Build the JSON payload for the blocking message endpoint.
-- @tparam table cfg reads providerID, modelID and agent
-- @param prompt text of the single text part
-- @treturn table payload; `model` is set only when both ids are truthy,
--   `agent` only when cfg.agent is truthy
local function buildMessageBody(cfg, prompt)
  local payload = {};
  payload.parts = { { type = 'text', text = prompt } };
  local hasModel = cfg.providerID and cfg.modelID;
  if hasModel then
    payload.model = { providerID = cfg.providerID, modelID = cfg.modelID };
  end
  if cfg.agent then
    payload.agent = cfg.agent;
  end
  return payload;
end
--- Generate a message id of the form `msg_<time>_<random>`.
-- `<time>` is `nowFunc() * 1000` floored; `<random>` is an integer drawn
-- from 100000..999999.
-- @treturn string
local function createMessageId()
  local stamp = math.floor(nowFunc() * 1000);
  local suffix = math.random(100000, 999999);
  return table.concat({ 'msg_', tostring(stamp), '_', tostring(suffix) });
end
--- Tell whether an assistant message is finished.
-- A message is complete when `info.finish` is a string, or when
-- `info.time.completed` is not nil.
-- @param message message table
-- @treturn boolean false when `message` or `message.info` is not a table
local function isMessageComplete(message)
  local info = type(message) == 'table' and message.info;
  if type(info) ~= 'table' then
    return false;
  end
  if type(info.finish) == 'string' then
    return true;
  end
  local timing = info.time;
  return type(timing) == 'table' and timing.completed ~= nil;
end
--- Extract a readable message from an assistant error object.
-- The first string among `data.message`, `message` and `name` wins.
-- @param errorInfo error table
-- @return nil when `errorInfo` is not a table; otherwise the message, or
--   'unknown assistant error' when none of the fields is a string
local function errorMessage(errorInfo)
  if type(errorInfo) ~= 'table' then
    return nil;
  end
  local nested = type(errorInfo.data) == 'table' and errorInfo.data.message or nil;
  -- Indexed explicitly (not with ipairs) because any slot may be nil.
  local candidates = { nested, errorInfo.message, errorInfo.name };
  for index = 1, 3 do
    local candidate = candidates[index];
    if type(candidate) == 'string' then
      return candidate;
    end
  end
  return 'unknown assistant error';
end
--- Render a session status entry as short text.
-- @param status status table with a string `type`
-- @return nil when `status` is not a table or `type` is not a string;
--   'retry #<attempt>: <message>' for type 'retry' (missing fields shown
--   as '?' and 'unknown error'); otherwise the type itself
local function sessionStatusText(status)
  local kind = type(status) == 'table' and status.type or nil;
  if type(kind) ~= 'string' then
    return nil;
  end
  if kind ~= 'retry' then
    return kind;
  end
  local attempt = tostring(status.attempt or '?');
  local reason = tostring(status.message or 'unknown error');
  return table.concat({ 'retry #', attempt, ': ', reason });
end
--- Decode a message response and check that it carries a `parts` table.
-- @param value a JSON string (decoded with textutils.unserializeJSON) or
--   an already-decoded value
-- @return the message table and nil, or nil and an error message
local function decodeMessage(value)
  local message;
  if type(value) == 'string' then
    message = textutils.unserializeJSON(value);
  else
    message = value;
  end
  local valid = type(message) == 'table' and type(message.parts) == 'table';
  if valid then
    return message, nil;
  end
  return nil, 'reponse message invalide';
end
--- Locate the assistant reply that belongs to a submitted message.
-- Scans in order and returns the first entry that either has the
-- submitted id and role 'assistant', or is an assistant message appearing
-- after a non-assistant entry carrying the submitted id. Entries without
-- an `info` table are ignored.
-- @tparam table messages sequence of message tables
-- @param submittedMessageId id of the submitted prompt
-- @return the matching message table, or nil
local function findAssistantMessage(messages, submittedMessageId)
  local promptSeen = false;
  for _, entry in ipairs(messages) do
    local info = type(entry) == 'table' and entry.info;
    if type(info) == 'table' then
      local isAssistant = info.role == 'assistant';
      if info.id == submittedMessageId then
        if isAssistant then
          return entry;
        end
        promptSeen = true;
      elseif promptSeen and isAssistant then
        return entry;
      end
    end
  end
  return nil;
end
--- Produce the failure for a session the server no longer knows.
-- When `persist` is truthy the stored session id is removed first, and the
-- settings are saved if the settings library provides `save`.
-- @param persist whether the session id is persisted in settings
-- @param sessionSettingKey settings key; defaults to DEFAULT_SESSION_SETTING_KEY
-- @return false and the error message
local function handleMissingSession(persist, sessionSettingKey)
  local failure = 'session introuvable; lance: ai new <prompt>';
  if not persist then
    return false, failure;
  end
  settingsLib.unset(sessionSettingKey or DEFAULT_SESSION_SETTING_KEY);
  if settingsLib.save then
    settingsLib.save();
  end
  return false, failure;
end
-- Forward declarations: pollMessage refers to these locals before the
-- functions are assigned to them further down in the file.
local doGet;
local doPost;
--- Poll a session's message list until the assistant reply is complete.
-- Attempts run on an event loop obtained from eventloopFactory: the first
-- one is scheduled with delay 0, later ones every cfg.pollIntervalSeconds,
-- until the deadline (now + cfg.pollTimeoutSeconds) has passed.
-- @tparam table cfg resolved config (pollTimeoutSeconds, pollIntervalSeconds)
-- @param sessionId session whose messages are fetched
-- @param messageId id of the submitted prompt
-- @param persist forwarded to handleMissingSession on HTTP 404
-- @param sessionSettingKey forwarded to handleMissingSession on HTTP 404
-- @tparam[opt] function log diagnostic sink; defaults to a no-op
-- @return true and { reply, sessionId, messageId } on success, otherwise
--   false and an error value
local function pollMessage(cfg, sessionId, messageId, persist, sessionSettingKey, log)
local loop = eventloopFactory();
local deadline = nowFunc() + cfg.pollTimeoutSeconds;
local resultOk, resultValue;
local attemptCount = 0;
log = log or function() end;
-- Record the outcome and stop the loop so runLoop() below returns.
local function finish(ok, value)
resultOk, resultValue = ok, value;
loop.stopLoop();
end
local function attempt()
attemptCount = attemptCount + 1;
local body, code = doGet(cfg, '/session/' .. sessionId .. '/message');
-- No body: the second return value is logged as a transient error and
-- the attempt is rescheduled unless the deadline has passed.
if not body then
log('poll #' .. tostring(attemptCount) .. ': transient error: ' .. tostring(code));
if nowFunc() >= deadline then
return finish(false, code);
end
return loop.setTimeout(attempt, cfg.pollIntervalSeconds);
end
if code == 404 then
local ok, value = handleMissingSession(persist, sessionSettingKey);
return finish(ok, value);
end
if code and code ~= 200 then
return finish(false, 'erreur message: HTTP ' .. tostring(code));
end
local messages = textutils.unserializeJSON(body);
if type(messages) ~= 'table' then
return finish(false, 'reponse message invalide');
end
-- `decoded` is nil until the assistant message shows up in the list.
local decoded = findAssistantMessage(messages, messageId);
local reply = decoded and extractTextParts(decoded.parts) or '';
local complete = decoded and isMessageComplete(decoded) or false;
local matchedId = decoded and type(decoded.info) == 'table' and decoded.info.id or 'nil';
local assistantError = decoded and type(decoded.info) == 'table' and errorMessage(decoded.info.error) or nil;
log('poll #' .. tostring(attemptCount)
.. ': messages=' .. tostring(#messages)
.. ', found=' .. tostring(matchedId)
.. ', complete=' .. tostring(complete)
.. ', text=' .. tostring(reply ~= '')
.. ', error=' .. tostring(assistantError ~= nil));
-- An error reported on the assistant message ends polling immediately.
if assistantError then
return finish(false, 'erreur assistant: ' .. assistantError);
end
-- Success needs both a non-empty text and a completed message.
if decoded and reply ~= '' and complete then
log('async reply completed');
return finish(true, { reply = reply, sessionId = sessionId, messageId = messageId });
end
-- Deadline reached: try to add the session status to the timeout error.
if nowFunc() >= deadline then
local statusBody, statusCodeValue = doGet(cfg, '/session/status');
if statusBody and (not statusCodeValue or statusCodeValue == 200) then
local statuses = textutils.unserializeJSON(statusBody);
local statusText = type(statuses) == 'table' and sessionStatusText(statuses[sessionId]) or nil;
if statusText then
log('session status at timeout: ' .. statusText);
return finish(false, 'delai depasse en attendant la reponse AI (status: ' .. statusText .. ')');
end
end
return finish(false, 'delai depasse en attendant la reponse AI');
end
loop.setTimeout(attempt, cfg.pollIntervalSeconds);
end
loop.setTimeout(attempt, 0);
loop.runLoop();
return resultOk, resultValue;
end
--- GET `path` through the HTTP client; returns whatever getJson returns.
doGet = function(cfg, path)
  return httpClient.getJson(cfg, path);
end;
--- POST `payload` to `path` through the HTTP client; returns whatever
-- postJson returns.
doPost = function(cfg, path, payload)
  return httpClient.postJson(cfg, path, payload);
end;
--- Send a prompt to the blocking message endpoint and return its reply.
-- @tparam table cfg resolved config
-- @param sessionId target session
-- @param prompt prompt text
-- @param persist forwarded to handleMissingSession on HTTP 404
-- @param sessionSettingKey forwarded to handleMissingSession on HTTP 404
-- @tparam function log diagnostic sink (required here)
-- @return true and { reply, sessionId, messageId }, or false and an error
local function askBlocking(cfg, sessionId, prompt, persist, sessionSettingKey, log)
  log('sending blocking message');
  local endpoint = '/session/' .. sessionId .. '/message';
  local body, code = doPost(cfg, endpoint, buildMessageBody(cfg, prompt));
  if not body then
    return false, code;
  end
  if code == 404 then
    return handleMissingSession(persist, sessionSettingKey);
  end
  if code and code ~= 200 then
    return false, 'erreur message: HTTP ' .. tostring(code);
  end
  local message, decodeErr = decodeMessage(body);
  if not message then
    return false, decodeErr;
  end
  local reply = extractTextParts(message.parts);
  if reply == '' then
    return false, 'reponse vide';
  end
  -- The reply id is reported only when `info` is a table with a truthy id.
  local replyId = nil;
  if type(message.info) == 'table' and message.info.id then
    replyId = message.info.id;
  end
  return true, {
    reply = reply,
    sessionId = sessionId,
    messageId = replyId,
  };
end
--- GET the session list with a `directory` query parameter.
-- @tparam table cfg resolved config
-- @param directory value of the `directory` parameter
-- @return whatever doGet returns
local function listSessionsWithDirectory(cfg, directory)
  local query = httpClient.queryString({ { 'directory', directory } });
  return doGet(cfg, '/session' .. query);
end
--- Decode a session-list response and sort it in place, highest
-- sessionTime first.
-- @param body JSON text of the response
-- @tparam function log diagnostic sink
-- @return the sorted table and nil, or nil and an error message when the
--   body does not decode to a table
local function decodeSessionList(body, log)
  local sessions = textutils.unserializeJSON(body);
  if type(sessions) == 'table' then
    local function newestFirst(left, right)
      return sessionTime(left) > sessionTime(right);
    end
    table.sort(sessions, newestFirst);
    return sessions, nil;
  end
  log('list sessions failed: invalid response');
  return nil, 'reponse invalide';
end
--- Forget the persisted session id.
-- Unsets `options.sessionSettingKey` (default DEFAULT_SESSION_SETTING_KEY)
-- and saves the settings when the settings library provides `save`.
-- @tparam[opt] table options
function api.clearSession(options)
  local key = (options or {}).sessionSettingKey or DEFAULT_SESSION_SETTING_KEY;
  settingsLib.unset(key);
  if settingsLib.save then
    settingsLib.save();
  end
end
--- List the server's sessions, sorted by decodeSessionList.
-- When a directory is configured the list is requested for that directory.
-- When no directory is configured and the list comes back empty, the
-- known session id (option or setting) is used to look up that session's
-- directory, and the list is requested again for it.
-- @tparam[opt] table options config overrides plus `log`, `sessionId`
--   and `sessionSettingKey`
-- @return true and the session list, or false and an error value
function api.listSessions(options)
options = options or {};
local log = options.log or function() end;
local cfg, err = resolveConfig(options);
if not cfg then return false, err; end
local directory = cfg.directory;
local sessionSettingKey = options.sessionSettingKey or DEFAULT_SESSION_SETTING_KEY;
log('listing sessions from ' .. cfg.url);
local body, code;
if isBlank(directory) then
body, code = doGet(cfg, '/session');
else
log('listing sessions for directory ' .. tostring(directory));
body, code = listSessionsWithDirectory(cfg, directory);
end
-- No body: the second return value is passed back as the error.
if not body then
log('list sessions failed: ' .. tostring(code));
return false, code;
end
if code and code ~= 200 then
log('list sessions failed: HTTP ' .. tostring(code));
return false, 'erreur serveur: HTTP ' .. tostring(code);
end
local decoded, decodeErr = decodeSessionList(body, log);
if not decoded then return false, decodeErr; end
-- Fallback: an empty unscoped list is retried with the directory of the
-- known session. A failed lookup or failed retry request keeps the empty
-- list; only an undecodable retry response is reported as an error.
if #decoded == 0 and isBlank(directory) then
local sessionId = options.sessionId or settingsLib.get(sessionSettingKey);
if not isBlank(sessionId) then
log('session list empty; resolving directory from ' .. tostring(sessionId));
local sessionBody, sessionCode = doGet(cfg, '/session/' .. sessionId);
if sessionBody and (not sessionCode or sessionCode == 200) then
local session = textutils.unserializeJSON(sessionBody);
if type(session) == 'table' and not isBlank(session.directory) then
log('retrying sessions for directory ' .. tostring(session.directory));
local scopedBody, scopedCode = listSessionsWithDirectory(cfg, session.directory);
if scopedBody and (not scopedCode or scopedCode == 200) then
local scoped, scopedErr = decodeSessionList(scopedBody, log);
if not scoped then return false, scopedErr; end
decoded = scoped;
end
end
end
end
end
log('sessions returned: ' .. tostring(#decoded));
return true, decoded;
end
--- Send a prompt to the AI server and return the assistant's reply.
-- A session is created when none is known. With persistence on (the
-- default) the session id is read from and written to settings.
-- @param prompt the user prompt; must not be blank
-- @tparam[opt] table options config overrides plus:
--   log                   diagnostic sink
--   persist               `false` disables reading/storing the session id
--   sessionSettingKey     settings key for the session id
--   sessionId             explicit session to reuse
--   sessionTitle          title for a newly created session (default 'cc-ai')
--   includeCallerContext  `false` sends the prompt without the caller-context block
--   blocking              `true` uses the blocking message endpoint
--   messageId             explicit id for the async prompt
-- @return true and { reply, sessionId, messageId }, or false and an error
function api.ask(prompt, options)
options = options or {};
local log = options.log or function() end;
if isBlank(prompt) then
return false, 'missing prompt; usage: ai <prompt>';
end
local cfg, err = resolveConfig(options);
if not cfg then return false, err; end
-- Persistence is on unless the caller passes exactly `false`.
local persist = options.persist ~= false;
local sessionSettingKey = options.sessionSettingKey or DEFAULT_SESSION_SETTING_KEY;
local sessionId = options.sessionId;
if persist and sessionId == nil then
sessionId = settingsLib.get(sessionSettingKey);
end
if not sessionId or sessionId == '' then
log('creating session');
local body, code = doPost(cfg, '/session', { title = options.sessionTitle or 'cc-ai' });
if not body then return false, code; end
if code and code ~= 200 then
return false, 'impossible de creer une session: HTTP ' .. tostring(code);
end
local decoded = textutils.unserializeJSON(body);
if type(decoded) ~= 'table' or type(decoded.id) ~= 'string' then
return false, 'reponse session invalide';
end
sessionId = decoded.id;
if persist then
settingsLib.set(sessionSettingKey, sessionId);
if settingsLib.save then settingsLib.save(); end
end
else
log('reusing session ' .. sessionId);
end
local promptWithContext = prompt;
if options.includeCallerContext ~= false then
promptWithContext = buildPromptWithCallerContext(prompt, osLib);
end
if options.blocking == true then
log('using blocking message endpoint');
return askBlocking(cfg, sessionId, promptWithContext, persist, sessionSettingKey, log);
end
local messageId = options.messageId or createMessageId();
log('sending async prompt ' .. messageId);
local body, code = doPost(cfg, '/session/' .. sessionId .. '/prompt_async',
buildPromptBody(cfg, messageId, promptWithContext));
if not body then return false, code; end
if code == 404 then
return handleMissingSession(persist, sessionSettingKey);
end
-- Both 200 and 204 count as accepted.
if code and code ~= 204 and code ~= 200 then
return false, 'erreur message: HTTP ' .. tostring(code);
end
-- A 200 with a non-empty body is decoded as the reply itself; every
-- other accepted response falls through to polling.
if code == 200 and body and body ~= '' then
local decoded, decodeErr = decodeMessage(body);
if not decoded then return false, decodeErr; end
local reply = extractTextParts(decoded.parts);
if reply == '' then return false, 'reponse vide'; end
return true, { reply = reply, sessionId = sessionId, messageId = messageId };
end
return pollMessage(cfg, sessionId, messageId, persist, sessionSettingKey, log);
end
--- Build an executor that runs Lua source text and captures its output.
--
-- Recognised `options` fields (all optional):
--   env                fallback table for global lookups (default `_G`)
--   live               unless `false`, captured output is also forwarded
--                      to the live print/write functions
--   print, write       live sinks (default: the global `print` / `write`)
--   luaTimeoutSeconds  handled by resolveLuaExecTimeout; `false` disables
--                      the timeout
--
-- The returned function takes the source `code` and returns either
--   true, output, nil      output is everything printed/written, followed
--                          by a tab-separated line of returned values
--   false, message, kind   kind is 'syntax' when loading failed,
--                          'identifier' or 'other' for runtime errors, and
--                          'other' for a timeout or a non-string `code`
-- @tparam[opt] table options
-- @treturn function executor(code)
function api.createLuaExecutor(options)
  options = options or {};
  local baseEnv = options.env or _G;
  local live = options.live ~= false;
  local livePrint = options.print or print;
  local liveWrite = options.write or write;
  local timeoutSeconds = resolveLuaExecTimeout(options);
  return function(code)
    -- `load` raises on values that are neither strings nor functions;
    -- report those through the normal failure triple instead.
    if type(code) ~= 'string' then
      return false, 'lua code must be a string, got ' .. type(code), 'other';
    end
    local buffer = {};
    local function append(text)
      buffer[#buffer + 1] = text;
    end
    local function capturedPrint(...)
      local values = tablePack(...);
      local line = valuesToLine(values, 1, values.n);
      append(line .. '\n');
      if live then livePrint(...); end
    end
    local function capturedWrite(text)
      text = tostring(text or '');
      append(text);
      if live then liveWrite(text); end
    end
    -- The chunk sees the capturing print/write; every other global is
    -- looked up in baseEnv. Globals it assigns land in `env`.
    local env = setmetatable({
      print = capturedPrint,
      write = capturedWrite,
    }, { __index = baseEnv });
    local chunk, loadErr = load(code, 'ai-lua-exec', 't', env);
    if not chunk then
      return false, tostring(loadErr), 'syntax';
    end
    local result;
    local finished = false;
    local function runner()
      result = tablePack(pcall(chunk));
      finished = true;
    end
    if timeoutSeconds then
      -- Race the chunk against a sleep; `finished` stays false when the
      -- sleep returns first.
      parallel.waitForAny(runner, function() sleep(timeoutSeconds); end);
    else
      runner();
    end
    if not finished then
      return false, 'lua execution timed out after ' .. tostring(timeoutSeconds) .. 's', 'other';
    end
    -- result[1] is pcall's status; result[2..n] are the error or the
    -- chunk's return values.
    if not result[1] then
      return false, tostring(result[2]), classifyLuaRuntimeError(result[2]);
    end
    if result.n > 1 then
      -- Put the returned values on a line of their own.
      if #buffer > 0 and not endsWithNewline(buffer[#buffer]) then
        append('\n');
      end
      append(valuesToLine(result, 2, result.n) .. '\n');
    end
    return true, table.concat(buffer), nil;
  end;
end
--- Answer a request by having the AI write Lua, running it locally, and
-- asking the AI to phrase the final answer from the output.
--
-- Flow: request code, execute it, and on a 'syntax' or 'identifier'
-- failure ask for corrected code, up to `maxRetries` corrections. Any
-- other failure kind stops immediately. All AI calls share one
-- non-persisted session.
--
-- @param userPrompt the user's request; must not be blank
-- @tparam[opt] table options connection overrides (serverUrl, username,
--   password, directory, providerID, modelID, agent, timeoutSeconds,
--   pollTimeoutSeconds, pollIntervalSeconds) plus `log`, `executor`,
--   `maxRetries` and the api.createLuaExecutor options
-- @return true and { reply, output, code, attempts, sessionId } on
--   success; false and a table with `error` and `attempts` (plus
--   `errorKind`, `code`, `output`, `sessionId`, `retryExhausted` where
--   known) on failure
function api.luaExec(userPrompt, options)
  options = options or {};
  if isBlank(userPrompt) then
    return false, { error = 'missing prompt; usage: ai lua-exec <prompt>', attempts = 0 };
  end
  local log = options.log or function() end;
  local executor = options.executor or api.createLuaExecutor(options);
  local maxRetries = resolveLuaExecMaxRetries(options);
  local maxAttempts = maxRetries + 1;
  local sessionId;
  -- Options for each api.ask call. Rebuilt every time so the session id
  -- learned from the first call is reused afterwards. Every connection
  -- option that resolveConfig reads is forwarded, including directory and
  -- the poll settings.
  local function askOptions()
    return {
      persist = false,
      sessionId = sessionId,
      sessionTitle = 'cc-ai lua-exec',
      serverUrl = options.serverUrl,
      username = options.username,
      password = options.password,
      directory = options.directory,
      providerID = options.providerID,
      modelID = options.modelID,
      agent = options.agent,
      timeoutSeconds = options.timeoutSeconds,
      pollTimeoutSeconds = options.pollTimeoutSeconds,
      pollIntervalSeconds = options.pollIntervalSeconds,
    };
  end
  log('requesting Lua from AI');
  local ok, result = api.ask(buildLuaExecPrompt(userPrompt), askOptions());
  if not ok then
    return false, { error = result, attempts = 0, errorKind = 'ai' };
  end
  sessionId = result.sessionId;
  log('session: ' .. sessionId);
  local code = result.reply;
  for attempt = 1, maxAttempts do
    log('attempt ' .. tostring(attempt) .. '/' .. tostring(maxAttempts));
    log('code:\n' .. code);
    local execOk, outputOrErr, errorKind = executor(code);
    if execOk then
      local output = outputOrErr or '';
      log('output:\n' .. renderOutput(output));
      log('requesting final reply');
      local finalOk, finalResult = api.ask(buildLuaOutputPrompt(userPrompt, output), askOptions());
      if not finalOk then
        return false, {
          error = finalResult,
          attempts = attempt,
          errorKind = 'ai',
          code = code,
          output = output,
          sessionId = sessionId,
        };
      end
      log('final reply received');
      return true, {
        reply = finalResult.reply,
        output = output,
        code = code,
        attempts = attempt,
        sessionId = sessionId,
      };
    end
    errorKind = errorKind or 'other';
    log('error (' .. tostring(errorKind) .. '):\n' .. tostring(outputOrErr));
    -- Only syntax and identifier errors are sent back for correction, and
    -- only while attempts remain.
    if (errorKind ~= 'syntax' and errorKind ~= 'identifier') or attempt >= maxAttempts then
      return false, {
        error = outputOrErr,
        attempts = attempt,
        errorKind = errorKind,
        code = code,
        sessionId = sessionId,
        retryExhausted = attempt >= maxAttempts,
      };
    end
    log('requesting corrected Lua');
    local correctionOk, correctionResult = api.ask(
      buildLuaCorrectionPrompt(userPrompt, code, outputOrErr, errorKind),
      askOptions()
    );
    if not correctionOk then
      return false, {
        error = correctionResult,
        attempts = attempt,
        errorKind = 'ai',
        code = code,
        sessionId = sessionId,
      };
    end
    code = correctionResult.reply;
  end
  -- Every loop iteration returns or continues, so this is a safety net.
  return false, { error = 'lua-exec failed unexpectedly', attempts = maxAttempts };
end
--- Check connectivity by sending PING_PROMPT through api.ask with the
-- caller-context block disabled.
-- The caller's `options` table is shallow-copied rather than modified, so
-- a table reused for later api.ask calls keeps its own
-- `includeCallerContext` value.
-- @tparam[opt] table options same options as api.ask
-- @return whatever api.ask returns
function api.ping(options)
  local pingOptions = {};
  for key, value in pairs(options or {}) do
    pingOptions[key] = value;
  end
  pingOptions.includeCallerContext = false;
  return api.ask(PING_PROMPT, pingOptions);
end
return api;
end
return createAi;