Page MenuHomePhabricator (Chris)

No OneTemporary

Authored By
Unknown
Size
77 KB
Referenced Files
None
Subscribers
None
diff --git a/main.lua b/main.lua
index e504c0e..703464d 100644
--- a/main.lua
+++ b/main.lua
@@ -1,94 +1,94 @@
--[[
_ _ _
_ __ ___ __ _| |_| |_ __ _| |_ __ _
| '_ ` _ \ / _` | __| __/ _` | __/ _` |
| | | | | | (_| | |_| || (_| | || (_| |
|_| |_| |_|\__,_|\__|\__\__,_|\__\__,_|
- v2.0
+ v2.1
Copyright 2020-2026 Matthew Hesketh <matthew@matthewhesketh.com>
See LICENSE for details
]]
local config = require('src.core.config')
local logger = require('src.core.logger')
local database = require('src.core.database')
local redis = require('src.core.redis')
local session = require('src.core.session')
local i18n = require('src.core.i18n')
local loader = require('src.core.loader')
local router = require('src.core.router')
local migrations = require('src.db.init')
-- 1. Load configuration
config.load('.env')
logger.init()
logger.info('mattata v%s starting...', config.VERSION)
-- 2. Validate required config
assert(config.bot_token(), 'BOT_TOKEN is required. Set it in .env or as an environment variable.')
-- 3. Configure telegram-bot-lua
-local api = require('telegram-bot-lua.core').configure(config.bot_token())
+local api = require('telegram-bot-lua').configure(config.bot_token())
local tools = require('telegram-bot-lua.tools')
logger.info('Bot: @%s (%s) [%d]', api.info.username, api.info.first_name, api.info.id)
-- 4. Connect to PostgreSQL
local db_ok, db_err = database.connect()
if not db_ok then
logger.error('Cannot start without PostgreSQL: %s', tostring(db_err))
os.exit(1)
end
-- 5. Run database migrations
migrations.run(database)
-- 6. Connect to Redis
local redis_ok, redis_err = redis.connect()
if not redis_ok then
logger.error('Cannot start without Redis: %s', tostring(redis_err))
os.exit(1)
end
session.init(redis)
-- 7. Load languages
i18n.init()
-- 8. Load all plugins
loader.init(api, database, redis)
-- 9. Build context factory and start router
local ctx_base = {
api = api,
tools = tools,
db = database,
redis = redis,
session = session,
config = config,
i18n = i18n,
permissions = require('src.core.permissions'),
logger = logger
}
router.init(api, tools, loader, ctx_base)
-- 10. Notify admins
local info_msg = string.format(
'<pre>mattata v%s connected!\n\n Username: @%s\n Name: %s\n ID: %d\n Plugins: %d</pre>',
config.VERSION,
tools.escape_html(api.info.username),
tools.escape_html(api.info.first_name),
api.info.id,
loader.count()
)
if config.log_chat() then
api.send_message(config.log_chat(), info_msg, 'html')
end
for _, admin_id in ipairs(config.bot_admins()) do
api.send_message(admin_id, info_msg, 'html')
end
-- 11. Start the bot
logger.info('Starting main loop...')
router.run()
diff --git a/spec/core/concurrency_spec.lua b/spec/core/concurrency_spec.lua
new file mode 100644
index 0000000..f743b53
--- /dev/null
+++ b/spec/core/concurrency_spec.lua
@@ -0,0 +1,168 @@
+--[[
+ Tests for v2.1 concurrency-safe pool mechanics.
+ Tests semaphore guards on database and Redis pools,
+ method-name-string safe_call, and graceful fallback
+ when semaphore is nil (test/mock environment).
+]]
+
+describe('concurrency', function()
+
+ describe('database pool safety', function()
+ local mock_db = require('spec.helpers.mock_db')
+
+ it('should work without semaphore (test environment fallback)', function()
+ local db = mock_db.new()
+ -- mock_db never calls connect(), so semaphore is nil
+ -- All operations should still work identically to v2.0
+ local result = db.query('SELECT 1')
+ assert.are.same({}, result)
+ db.set_next_result({ { id = 1 } })
+ result = db.query('SELECT * FROM users')
+ assert.are.same({ { id = 1 } }, result)
+ end)
+
+ it('should execute queries and return results without semaphore', function()
+ local db = mock_db.new()
+ db.set_next_result({ { count = 42 } })
+ local result = db.execute('SELECT COUNT(*) FROM messages', {})
+ assert.are.same({ { count = 42 } }, result)
+ end)
+
+ it('should handle transactions without semaphore', function()
+ local db = mock_db.new()
+ local called = false
+ db.transaction(function(query, execute)
+ called = true
+ query('SELECT 1')
+ end)
+ assert.is_true(called)
+ end)
+
+ it('pool_stats should report available and max_size', function()
+ local db = mock_db.new()
+ local stats = db.pool_stats()
+ assert.is_number(stats.available)
+ assert.is_number(stats.max_size)
+ end)
+ end)
+
+ describe('redis pool safety', function()
+ local mock_redis = require('spec.helpers.mock_redis')
+
+ it('should work without semaphore (test environment fallback)', function()
+ local redis = mock_redis.new()
+ -- mock_redis never calls connect() with real config, so no semaphore
+ redis.set('key', 'value')
+ assert.are.equal('value', redis.get('key'))
+ end)
+
+ it('safe_call with method name strings should work via proxy functions', function()
+ local redis = mock_redis.new()
+ -- All proxy functions in the real module use safe_call('method_name', ...)
+ -- The mock simulates this directly
+ redis.set('test_key', 'test_value')
+ assert.are.equal('test_value', redis.get('test_key'))
+ assert.is_true(redis.has_command('set'))
+ assert.is_true(redis.has_command('get'))
+ end)
+
+ it('hash operations should work through proxy', function()
+ local redis = mock_redis.new()
+ redis.hset('hash:1', 'field1', 'value1')
+ assert.are.equal('value1', redis.hget('hash:1', 'field1'))
+ local all = redis.hgetall('hash:1')
+ assert.are.equal('value1', all['field1'])
+ end)
+
+ it('set operations should work through proxy', function()
+ local redis = mock_redis.new()
+ redis.sadd('set:1', 'member1')
+ assert.are.equal(1, redis.sismember('set:1', 'member1'))
+ assert.are.equal(0, redis.sismember('set:1', 'nonexistent'))
+ end)
+
+ it('list operations should work through proxy', function()
+ local redis = mock_redis.new()
+ redis.rpush('list:1', 'a')
+ redis.rpush('list:1', 'b')
+ local items = redis.lrange('list:1', 0, -1)
+ assert.are.equal(2, #items)
+ assert.are.equal('a', items[1])
+ end)
+
+ it('scan should work through pool', function()
+ local redis = mock_redis.new()
+ redis.set('prefix:1', 'a')
+ redis.set('prefix:2', 'b')
+ redis.set('other:1', 'c')
+ local results = redis.scan('prefix:*')
+ assert.are.equal(2, #results)
+ end)
+
+ it('client() should still return a usable object', function()
+ local redis = mock_redis.new()
+ local c = redis.client()
+ assert.is_not_nil(c)
+ end)
+ end)
+
+ describe('mock_api async stubs', function()
+ local mock_api = require('spec.helpers.mock_api')
+
+ it('should have handler stubs', function()
+ local api = mock_api.new()
+ assert.is_function(api.on_message)
+ assert.is_function(api.on_edited_message)
+ assert.is_function(api.on_callback_query)
+ assert.is_function(api.on_inline_query)
+ end)
+
+ it('should have async module stubs', function()
+ local api = mock_api.new()
+ assert.is_table(api.async)
+ assert.is_function(api.async.run)
+ assert.is_function(api.async.stop)
+ assert.is_function(api.async.all)
+ assert.is_function(api.async.spawn)
+ assert.is_function(api.async.sleep)
+ assert.is_function(api.async.is_running)
+ end)
+
+ it('should have api.run stub', function()
+ local api = mock_api.new()
+ assert.is_function(api.run)
+ -- Should not error when called
+ assert.has_no.errors(function()
+ api.run({ timeout = 60, limit = 100 })
+ end)
+ end)
+
+ it('should have process_update stub', function()
+ local api = mock_api.new()
+ assert.is_function(api.process_update)
+ assert.has_no.errors(function()
+ api.process_update({ update_id = 1, message = {} })
+ end)
+ end)
+
+ it('async.is_running should return false in test environment', function()
+ local api = mock_api.new()
+ assert.is_false(api.async.is_running())
+ end)
+
+ it('async.spawn should execute the function', function()
+ local api = mock_api.new()
+ local called = false
+ api.async.spawn(function() called = true end)
+ assert.is_true(called)
+ end)
+
+ it('handler stubs should be overwritable', function()
+ local api = mock_api.new()
+ local msg_received = nil
+ api.on_message = function(msg) msg_received = msg end
+ api.on_message({ text = 'hello' })
+ assert.are.same({ text = 'hello' }, msg_received)
+ end)
+ end)
+end)
diff --git a/spec/core/config_spec.lua b/spec/core/config_spec.lua
index a1c8ddc..9f30156 100644
--- a/spec/core/config_spec.lua
+++ b/spec/core/config_spec.lua
@@ -1,364 +1,364 @@
--[[
Tests for src/core/config.lua
Config module: loading .env, get/set values, typed access (number, boolean, list).
]]
describe('core.config', function()
local config
local tmpfile
-- Write a temporary .env file for testing
local function write_env(content)
tmpfile = os.tmpname()
local f = io.open(tmpfile, 'w')
f:write(content)
f:close()
return tmpfile
end
before_each(function()
-- Clear the cached module so each test gets a fresh config
package.loaded['src.core.config'] = nil
config = require('src.core.config')
end)
after_each(function()
if tmpfile then
os.remove(tmpfile)
tmpfile = nil
end
end)
describe('load()', function()
it('should load a valid .env file', function()
local path = write_env('FOO=bar\nBAZ=qux\n')
config.load(path)
assert.are.equal('bar', config.get('FOO'))
assert.are.equal('qux', config.get('BAZ'))
end)
it('should not error on missing .env file', function()
assert.has_no.errors(function()
config.load('/tmp/nonexistent_env_file_' .. os.time())
end)
end)
it('should ignore empty lines', function()
local path = write_env('FOO=bar\n\n\nBAZ=qux\n')
config.load(path)
assert.are.equal('bar', config.get('FOO'))
assert.are.equal('qux', config.get('BAZ'))
end)
it('should ignore comment lines', function()
local path = write_env('# This is a comment\nFOO=bar\n# Another comment\n')
config.load(path)
assert.are.equal('bar', config.get('FOO'))
end)
it('should strip surrounding double quotes from values', function()
local path = write_env('FOO="hello world"\n')
config.load(path)
assert.are.equal('hello world', config.get('FOO'))
end)
it('should strip surrounding single quotes from values', function()
local path = write_env("FOO='hello world'\n")
config.load(path)
assert.are.equal('hello world', config.get('FOO'))
end)
it('should strip inline comments from unquoted values', function()
local path = write_env('FOO=bar # this is a comment\n')
config.load(path)
assert.are.equal('bar', config.get('FOO'))
end)
it('should handle values with equals signs', function()
local path = write_env('FOO=bar=baz\n')
config.load(path)
assert.are.equal('bar=baz', config.get('FOO'))
end)
it('should trim whitespace around keys and values', function()
local path = write_env(' FOO = bar \n')
config.load(path)
assert.are.equal('bar', config.get('FOO'))
end)
end)
describe('get()', function()
it('should return the value for a known key', function()
local path = write_env('MY_KEY=my_value\n')
config.load(path)
assert.are.equal('my_value', config.get('MY_KEY'))
end)
it('should return default when key is missing', function()
local path = write_env('OTHER_KEY=other\n')
config.load(path)
assert.are.equal('fallback', config.get('NONEXISTENT', 'fallback'))
end)
it('should return nil when key is missing and no default', function()
local path = write_env('')
config.load(path)
assert.is_nil(config.get('NONEXISTENT'))
end)
it('should fall back to os.getenv for empty .env values', function()
local path = write_env('EMPTY_KEY=\n')
config.load(path)
-- This will either return nil or whatever the OS env has
local result = config.get('EMPTY_KEY', 'default_val')
assert.is_not_nil(result)
end)
it('should auto-load .env if not explicitly loaded', function()
-- Just calling get() without load() should not error
assert.has_no.errors(function()
config.get('ANYTHING')
end)
end)
end)
describe('get_number()', function()
it('should return a number for numeric values', function()
local path = write_env('PORT=8080\n')
config.load(path)
assert.are.equal(8080, config.get_number('PORT'))
end)
it('should return default for non-numeric values', function()
local path = write_env('PORT=abc\n')
config.load(path)
assert.are.equal(3000, config.get_number('PORT', 3000))
end)
it('should return default when key is missing', function()
local path = write_env('')
config.load(path)
assert.are.equal(5432, config.get_number('DB_PORT', 5432))
end)
it('should return nil when key is missing and no default', function()
local path = write_env('')
config.load(path)
assert.is_nil(config.get_number('MISSING'))
end)
it('should handle float values', function()
local path = write_env('RATE=1.5\n')
config.load(path)
assert.are.equal(1.5, config.get_number('RATE'))
end)
it('should handle negative numbers', function()
local path = write_env('OFFSET=-10\n')
config.load(path)
assert.are.equal(-10, config.get_number('OFFSET'))
end)
end)
describe('is_enabled()', function()
it('should return true for "true"', function()
local path = write_env('FLAG=true\n')
config.load(path)
assert.is_true(config.is_enabled('FLAG'))
end)
it('should return true for "1"', function()
local path = write_env('FLAG=1\n')
config.load(path)
assert.is_true(config.is_enabled('FLAG'))
end)
it('should return true for "yes"', function()
local path = write_env('FLAG=yes\n')
config.load(path)
assert.is_true(config.is_enabled('FLAG'))
end)
it('should return true case-insensitively for "TRUE"', function()
local path = write_env('FLAG=TRUE\n')
config.load(path)
assert.is_true(config.is_enabled('FLAG'))
end)
it('should return true case-insensitively for "Yes"', function()
local path = write_env('FLAG=Yes\n')
config.load(path)
assert.is_true(config.is_enabled('FLAG'))
end)
it('should return false for "false"', function()
local path = write_env('FLAG=false\n')
config.load(path)
assert.is_false(config.is_enabled('FLAG'))
end)
it('should return false for "0"', function()
local path = write_env('FLAG=0\n')
config.load(path)
assert.is_false(config.is_enabled('FLAG'))
end)
it('should return false for "no"', function()
local path = write_env('FLAG=no\n')
config.load(path)
assert.is_false(config.is_enabled('FLAG'))
end)
it('should return false for missing key', function()
local path = write_env('')
config.load(path)
assert.is_false(config.is_enabled('MISSING'))
end)
it('should return false for arbitrary string', function()
local path = write_env('FLAG=maybe\n')
config.load(path)
assert.is_false(config.is_enabled('FLAG'))
end)
end)
describe('get_list()', function()
it('should split comma-separated values into a table', function()
local path = write_env('ITEMS=a,b,c\n')
config.load(path)
local list = config.get_list('ITEMS')
assert.are.equal(3, #list)
assert.are.equal('a', list[1])
assert.are.equal('b', list[2])
assert.are.equal('c', list[3])
end)
it('should trim whitespace around items', function()
local path = write_env('ITEMS= a , b , c \n')
config.load(path)
local list = config.get_list('ITEMS')
assert.are.equal(3, #list)
assert.are.equal('a', list[1])
assert.are.equal('b', list[2])
assert.are.equal('c', list[3])
end)
it('should convert numeric items to numbers', function()
local path = write_env('IDS=100,200,300\n')
config.load(path)
local list = config.get_list('IDS')
assert.are.equal(3, #list)
assert.are.equal(100, list[1])
assert.are.equal(200, list[2])
assert.are.equal(300, list[3])
end)
it('should return empty table for missing key', function()
local path = write_env('')
config.load(path)
local list = config.get_list('MISSING')
assert.are.same({}, list)
end)
it('should return empty table for empty value', function()
local path = write_env('ITEMS=\n')
config.load(path)
local list = config.get_list('ITEMS')
assert.are.same({}, list)
end)
it('should handle single-item list', function()
local path = write_env('ITEMS=only\n')
config.load(path)
local list = config.get_list('ITEMS')
assert.are.equal(1, #list)
assert.are.equal('only', list[1])
end)
it('should handle mixed numeric and string items', function()
local path = write_env('ITEMS=100,hello,300\n')
config.load(path)
local list = config.get_list('ITEMS')
assert.are.equal(100, list[1])
assert.are.equal('hello', list[2])
assert.are.equal(300, list[3])
end)
end)
describe('convenience accessors', function()
it('should return bot_token', function()
local path = write_env('BOT_TOKEN=12345:ABCDEF\n')
config.load(path)
assert.are.equal('12345:ABCDEF', config.bot_token())
end)
it('should return bot_admins as a list', function()
local path = write_env('BOT_ADMINS=221714512,123456\n')
config.load(path)
local admins = config.bot_admins()
assert.are.equal(2, #admins)
assert.are.equal(221714512, admins[1])
end)
it('should return bot_name with default', function()
local path = write_env('')
config.load(path)
assert.are.equal('mattata', config.bot_name())
end)
it('should return database config with defaults', function()
local path = write_env('')
config.load(path)
local db = config.database()
assert.are.equal('postgres', db.host)
assert.are.equal(5432, db.port)
assert.are.equal('mattata', db.database)
end)
it('should return redis config with defaults', function()
local path = write_env('')
config.load(path)
local rc = config.redis_config()
assert.are.equal('redis', rc.host)
assert.are.equal(6379, rc.port)
assert.are.equal(0, rc.db)
end)
it('should return polling config with defaults', function()
local path = write_env('')
config.load(path)
local p = config.polling()
assert.are.equal(60, p.timeout)
assert.are.equal(100, p.limit)
end)
it('should return webhook config', function()
local path = write_env('WEBHOOK_ENABLED=true\nWEBHOOK_URL=https://example.com\nWEBHOOK_PORT=8443\n')
config.load(path)
local wh = config.webhook()
assert.is_true(wh.enabled)
assert.are.equal('https://example.com', wh.url)
assert.are.equal(8443, wh.port)
end)
it('should return debug status', function()
local path = write_env('DEBUG=true\n')
config.load(path)
assert.is_true(config.debug())
end)
it('should return ai config with defaults', function()
local path = write_env('')
config.load(path)
local ai = config.ai()
assert.is_false(ai.enabled)
assert.are.equal('gpt-4o', ai.openai_model)
end)
end)
describe('VERSION', function()
- it('should be 2.0', function()
- assert.are.equal('2.0', config.VERSION)
+ it('should be 2.1', function()
+ assert.are.equal('2.1', config.VERSION)
end)
end)
end)
diff --git a/spec/helpers/mock_api.lua b/spec/helpers/mock_api.lua
index bac61ec..b7552a0 100644
--- a/spec/helpers/mock_api.lua
+++ b/spec/helpers/mock_api.lua
@@ -1,216 +1,243 @@
--[[
- mattata v2.0 - Mock Telegram Bot API
+ mattata v2.1 - Mock Telegram Bot API
Records all calls and returns configurable responses for testing.
+ Includes async/handler stubs for copas-based concurrency support.
]]
local mock_api = {}
function mock_api.new()
local api = {
info = { id = 123456789, username = 'testbot', first_name = 'Test Bot' },
calls = {},
}
local custom_handlers = {}
local function record(method, ...)
table.insert(api.calls, { method = method, args = {...} })
end
function api.send_message(chat_id, text, parse_mode, ...)
record('send_message', chat_id, text, parse_mode, ...)
return { ok = true, result = { message_id = #api.calls, chat = { id = chat_id } } }
end
function api.get_chat_member(chat_id, user_id)
record('get_chat_member', chat_id, user_id)
if custom_handlers.get_chat_member then
return custom_handlers.get_chat_member(chat_id, user_id)
end
-- Default: regular member
return { ok = true, result = { status = 'member', user = { id = user_id } } }
end
function api.ban_chat_member(chat_id, user_id, until_date)
record('ban_chat_member', chat_id, user_id, until_date)
return { ok = true, result = true }
end
function api.unban_chat_member(chat_id, user_id)
record('unban_chat_member', chat_id, user_id)
return { ok = true, result = true }
end
function api.restrict_chat_member(chat_id, user_id, perms_or_until, maybe_perms)
record('restrict_chat_member', chat_id, user_id, perms_or_until, maybe_perms)
return { ok = true, result = true }
end
function api.delete_message(chat_id, message_id)
record('delete_message', chat_id, message_id)
return { ok = true, result = true }
end
function api.pin_chat_message(chat_id, message_id, disable_notification)
record('pin_chat_message', chat_id, message_id, disable_notification)
return { ok = true, result = true }
end
function api.unpin_chat_message(chat_id, message_id)
record('unpin_chat_message', chat_id, message_id)
return { ok = true, result = true }
end
function api.get_chat(chat_id)
record('get_chat', chat_id)
return { ok = true, result = { id = chat_id, first_name = 'Test User' } }
end
function api.edit_message_text(chat_id, message_id, text, parse_mode, ...)
record('edit_message_text', chat_id, message_id, text, parse_mode, ...)
return { ok = true, result = { message_id = message_id } }
end
function api.edit_message_reply_markup(chat_id, message_id, inline_message_id, keyboard)
record('edit_message_reply_markup', chat_id, message_id, inline_message_id, keyboard)
return { ok = true, result = { message_id = message_id } }
end
function api.answer_callback_query(callback_id, text)
record('answer_callback_query', callback_id, text)
return { ok = true }
end
function api.get_updates(timeout, offset, limit, allowed)
record('get_updates', timeout, offset, limit, allowed)
return { ok = true, result = {} }
end
function api.leave_chat(chat_id)
record('leave_chat', chat_id)
return { ok = true, result = true }
end
function api.inline_keyboard()
local kb = {}
function kb:row(...)
return self
end
return kb
end
function api.row()
local r = {}
function r:callback_data_button(text, data)
return self
end
function r:url_button(text, url)
return self
end
return r
end
-- Helper to set custom get_chat_member behavior
function api.set_admin(chat_id, user_id)
local original_handler = custom_handlers.get_chat_member
custom_handlers.get_chat_member = function(cid, uid)
if cid == chat_id and uid == user_id then
return {
ok = true,
result = {
status = 'administrator',
user = { id = uid },
can_restrict_members = true,
can_delete_messages = true,
can_pin_messages = true,
can_promote_members = true,
can_invite_users = true,
}
}
end
if original_handler then
return original_handler(cid, uid)
end
return { ok = true, result = { status = 'member', user = { id = uid } } }
end
end
-- Helper to set the bot as an admin with specified permissions
function api.set_bot_admin(chat_id, perms)
perms = perms or {}
local original_handler = custom_handlers.get_chat_member
custom_handlers.get_chat_member = function(cid, uid)
if cid == chat_id and uid == api.info.id then
return {
ok = true,
result = {
status = 'administrator',
user = { id = uid },
can_restrict_members = perms.can_restrict_members or false,
can_delete_messages = perms.can_delete_messages or false,
can_pin_messages = perms.can_pin_messages or false,
can_promote_members = perms.can_promote_members or false,
can_invite_users = perms.can_invite_users or false,
}
}
end
if original_handler then
return original_handler(cid, uid)
end
return { ok = true, result = { status = 'member', user = { id = uid } } }
end
end
function api.set_creator(chat_id, user_id)
local original_handler = custom_handlers.get_chat_member
custom_handlers.get_chat_member = function(cid, uid)
if cid == chat_id and uid == user_id then
return {
ok = true,
result = {
status = 'creator',
user = { id = uid },
}
}
end
if original_handler then
return original_handler(cid, uid)
end
return { ok = true, result = { status = 'member', user = { id = uid } } }
end
end
+ -- Handler stubs (overwritten by router.run() in production)
+ api.on_message = function() end
+ api.on_edited_message = function() end
+ api.on_callback_query = function() end
+ api.on_inline_query = function() end
+
+ -- Async stubs (telegram-bot-lua async system)
+ api.async = {
+ run = function() end,
+ stop = function() end,
+ all = function(fns) return {} end,
+ spawn = function(fn) if fn then fn() end end,
+ sleep = function() end,
+ is_running = function() return false end,
+ }
+
+ -- api.run stub — no-op (prevents tests from entering copas.loop)
+ function api.run(opts)
+ record('run', opts)
+ end
+
+ -- process_update stub
+ function api.process_update(update)
+ record('process_update', update)
+ end
+
function api.reset()
api.calls = {}
custom_handlers = {}
end
function api.get_call(method)
for _, call in ipairs(api.calls) do
if call.method == method then return call end
end
return nil
end
function api.get_calls(method)
local results = {}
for _, call in ipairs(api.calls) do
if call.method == method then
table.insert(results, call)
end
end
return results
end
function api.count_calls(method)
local count = 0
for _, call in ipairs(api.calls) do
if call.method == method then count = count + 1 end
end
return count
end
return api
end
return mock_api
diff --git a/src/core/config.lua b/src/core/config.lua
index 44769b2..4366798 100644
--- a/src/core/config.lua
+++ b/src/core/config.lua
@@ -1,163 +1,163 @@
--[[
mattata v2.0 - Configuration Module
Reads configuration from .env file with os.getenv() fallback.
Provides typed access to all configuration values.
]]
local config = {}
local env_values = {}
local loaded = false
-- Parse a .env file into a table
local function parse_env_file(path)
local values = {}
local file = io.open(path, 'r')
if not file then
return values
end
for line in file:lines() do
line = line:match('^%s*(.-)%s*$') -- trim
if line ~= '' and not line:match('^#') then
local key, value = line:match('^([%w_]+)%s*=%s*(.*)$')
if key then
-- Strip surrounding quotes
value = value:match('^"(.*)"$') or value:match("^'(.*)'$") or value
-- Strip inline comments (only for unquoted values)
value = value:match('^(.-)%s+#') or value
values[key] = value
end
end
end
file:close()
return values
end
-- Load .env file (called once)
function config.load(path)
path = path or '.env'
env_values = parse_env_file(path)
loaded = true
end
-- Get a string value with optional default
function config.get(key, default)
if not loaded then
config.load()
end
local value = env_values[key]
if value == nil or value == '' then
value = os.getenv(key)
end
if value == nil or value == '' then
return default
end
return value
end
-- Get a numeric value
function config.get_number(key, default)
local value = config.get(key)
if value == nil then
return default
end
return tonumber(value) or default
end
-- Get a boolean value
function config.is_enabled(key)
local value = config.get(key)
if value == nil then
return false
end
value = value:lower()
return value == 'true' or value == '1' or value == 'yes'
end
-- Get a comma-separated list as a table
function config.get_list(key)
local value = config.get(key)
if not value or value == '' then
return {}
end
local list = {}
for item in value:gmatch('[^,]+') do
item = item:match('^%s*(.-)%s*$')
if item ~= '' then
local num = tonumber(item)
table.insert(list, num or item)
end
end
return list
end
-- Convenience accessors for common config groups
function config.bot_token()
return config.get('BOT_TOKEN')
end
function config.bot_admins()
return config.get_list('BOT_ADMINS')
end
function config.bot_name()
return config.get('BOT_NAME', 'mattata')
end
function config.database()
return {
host = config.get('DATABASE_HOST', 'postgres'),
port = config.get_number('DATABASE_PORT', 5432),
database = config.get('DATABASE_NAME', 'mattata'),
user = config.get('DATABASE_USER', 'mattata'),
password = config.get('DATABASE_PASSWORD', 'changeme')
}
end
function config.redis_config()
return {
host = config.get('REDIS_HOST', 'redis'),
port = config.get_number('REDIS_PORT', 6379),
password = config.get('REDIS_PASSWORD'),
db = config.get_number('REDIS_DB', 0)
}
end
function config.polling()
return {
timeout = config.get_number('POLLING_TIMEOUT', 60),
limit = config.get_number('POLLING_LIMIT', 100)
}
end
function config.webhook()
return {
enabled = config.is_enabled('WEBHOOK_ENABLED'),
url = config.get('WEBHOOK_URL'),
port = config.get_number('WEBHOOK_PORT', 8443),
secret = config.get('WEBHOOK_SECRET')
}
end
function config.ai()
return {
enabled = config.is_enabled('AI_ENABLED'),
openai_key = config.get('OPENAI_API_KEY'),
openai_model = config.get('OPENAI_MODEL', 'gpt-4o'),
anthropic_key = config.get('ANTHROPIC_API_KEY'),
anthropic_model = config.get('ANTHROPIC_MODEL', 'claude-sonnet-4-5-20250929')
}
end
function config.debug()
return config.is_enabled('DEBUG')
end
function config.log_chat()
return config.get_number('LOG_CHAT')
end
-- Version constant
-config.VERSION = '2.0'
+config.VERSION = '2.1'
return config
diff --git a/src/core/database.lua b/src/core/database.lua
index d832951..fe99e4a 100644
--- a/src/core/database.lua
+++ b/src/core/database.lua
@@ -1,283 +1,327 @@
--[[
- mattata v2.0 - PostgreSQL Database Module
+ mattata v2.1 - PostgreSQL Database Module
Uses pgmoon for async-compatible PostgreSQL connections.
- Implements connection pooling, automatic reconnection, and transaction helpers.
+ Implements connection pooling with copas semaphore guards,
+ automatic reconnection, and transaction helpers.
]]
local database = {}
local pgmoon = require('pgmoon')
local config = require('src.core.config')
local logger = require('src.core.logger')
+local copas_sem = require('copas.semaphore')
local pool = {}
local pool_size = 10
local pool_timeout = 30000
+local pool_semaphore = nil
local db_config = nil
-- Initialise pool configuration
local function get_config()
if not db_config then
db_config = config.database()
end
return db_config
end
-- Create a new pgmoon connection
local function create_connection()
local cfg = get_config()
local pg = pgmoon.new({
host = cfg.host,
port = cfg.port,
database = cfg.database,
user = cfg.user,
password = cfg.password
})
local ok, err = pg:connect()
if not ok then
return nil, err
end
pg:settimeout(pool_timeout)
return pg
end
function database.connect()
local cfg = get_config()
pool_size = config.get_number('DATABASE_POOL_SIZE', 10)
pool_timeout = config.get_number('DATABASE_TIMEOUT', 30000)
-- Create initial connection to validate credentials
local pg, err = create_connection()
if not pg then
logger.error('Failed to connect to PostgreSQL: %s', tostring(err))
return false, err
end
table.insert(pool, pg)
+
+ -- Create semaphore to guard concurrent pool access
+ -- max = pool_size, start = pool_size (all permits available), timeout = 30s
+ pool_semaphore = copas_sem.new(pool_size, pool_size, 30)
+
logger.info('Connected to PostgreSQL at %s:%d/%s (pool size: %d)', cfg.host, cfg.port, cfg.database, pool_size)
return true
end
-- Acquire a connection from the pool
function database.acquire()
+ -- Take a semaphore permit (blocks coroutine if pool exhausted, 30s timeout)
+ if pool_semaphore then
+ local ok, err = pool_semaphore:take(1, 30)
+ if not ok then
+ logger.error('Failed to acquire pool permit: %s', tostring(err))
+ return nil, 'Pool exhausted (semaphore timeout)'
+ end
+ end
if #pool > 0 then
return table.remove(pool)
end
-- Pool exhausted — create a new connection
local pg, err = create_connection()
if not pg then
logger.error('Failed to create new connection: %s', tostring(err))
+ -- Return the permit since we failed to use it
+ if pool_semaphore then pool_semaphore:give(1) end
return nil, err
end
return pg
end
-- Release a connection back to the pool
function database.release(pg)
if not pg then return end
if #pool < pool_size then
table.insert(pool, pg)
else
pcall(function() pg:disconnect() end)
end
+ -- Return the semaphore permit
+ if pool_semaphore then pool_semaphore:give(1) end
end
-- Execute a raw SQL query with automatic connection management
function database.query(sql, ...)
local pg, err = database.acquire()
if not pg then
logger.error('Database not connected')
return nil, 'Database not connected'
end
- local result, query_err, partial, num_queries = pg:query(sql)
+ local result, query_err, _, _ = pg:query(sql)
if not result then
-- Check for connection loss and attempt reconnect
if query_err and (query_err:match('closed') or query_err:match('broken') or query_err:match('timeout')) then
logger.warn('Connection lost, attempting reconnect...')
pcall(function() pg:disconnect() end)
+ -- Release the dead connection's permit before reconnect
+ if pool_semaphore then pool_semaphore:give(1) end
pg, err = create_connection()
if pg then
+ -- Re-acquire a permit for the new connection
+ if pool_semaphore then
+ local ok, sem_err = pool_semaphore:take(1, 30)
+ if not ok then
+ pcall(function() pg:disconnect() end)
+ logger.error('Reconnect semaphore acquire failed: %s', tostring(sem_err))
+ return nil, 'Pool exhausted during reconnect'
+ end
+ end
result, query_err = pg:query(sql)
if result then
database.release(pg)
return result
end
+ database.release(pg)
end
logger.error('Reconnect failed for query: %s', tostring(query_err or err))
return nil, query_err or err
end
logger.error('Query failed: %s\nSQL: %s', tostring(query_err), sql)
database.release(pg)
return nil, query_err
end
database.release(pg)
return result
end
-- Execute a parameterized query (manually escape values)
function database.execute(sql, params)
- local pg, err = database.acquire()
+ local pg, _ = database.acquire()
if not pg then
return nil, 'Database not connected'
end
if params then
local escaped = {}
for i, v in ipairs(params) do
if v == nil then
escaped[i] = 'NULL'
elseif type(v) == 'number' then
escaped[i] = tostring(v)
elseif type(v) == 'boolean' then
escaped[i] = v and 'TRUE' or 'FALSE'
else
escaped[i] = pg:escape_literal(tostring(v))
end
end
-- Replace $1, $2, etc. with escaped values
sql = sql:gsub('%$(%d+)', function(n)
return escaped[tonumber(n)] or '$' .. n
end)
end
local result, query_err = pg:query(sql)
if not result then
-- Attempt reconnect on connection failure
if query_err and (query_err:match('closed') or query_err:match('broken') or query_err:match('timeout')) then
logger.warn('Connection lost during execute, reconnecting...')
pcall(function() pg:disconnect() end)
+ -- Release the dead connection's permit before reconnect
+ if pool_semaphore then pool_semaphore:give(1) end
local new_pg
- new_pg, err = create_connection()
+ new_pg, _ = create_connection()
if new_pg then
+ -- Re-acquire a permit for the new connection
+ if pool_semaphore then
+ local ok, sem_err = pool_semaphore:take(1, 30)
+ if not ok then
+ pcall(function() new_pg:disconnect() end)
+ logger.error('Reconnect semaphore acquire failed: %s', tostring(sem_err))
+ return nil, 'Pool exhausted during reconnect'
+ end
+ end
result, query_err = new_pg:query(sql)
if result then
database.release(new_pg)
return result
end
database.release(new_pg)
end
else
database.release(pg)
end
logger.error('Query failed: %s\nSQL: %s', tostring(query_err), sql)
return nil, query_err
end
database.release(pg)
return result
end
-- Run a function inside a transaction (BEGIN / COMMIT / ROLLBACK)
function database.transaction(fn)
- local pg, err = database.acquire()
+ local pg, _ = database.acquire()
if not pg then
return nil, 'Database not connected'
end
local ok, begin_err = pg:query('BEGIN')
if not ok then
database.release(pg)
return nil, begin_err
end
-- Build a scoped query function for this connection
local function scoped_query(sql)
return pg:query(sql)
end
local function scoped_execute(sql, params)
if params then
local escaped = {}
for i, v in ipairs(params) do
if v == nil then
escaped[i] = 'NULL'
elseif type(v) == 'number' then
escaped[i] = tostring(v)
elseif type(v) == 'boolean' then
escaped[i] = v and 'TRUE' or 'FALSE'
else
escaped[i] = pg:escape_literal(tostring(v))
end
end
sql = sql:gsub('%$(%d+)', function(n)
return escaped[tonumber(n)] or '$' .. n
end)
end
return pg:query(sql)
end
local success, result = pcall(fn, scoped_query, scoped_execute)
if success then
pg:query('COMMIT')
database.release(pg)
return result
else
pg:query('ROLLBACK')
database.release(pg)
logger.error('Transaction failed: %s', tostring(result))
return nil, result
end
end
-- Convenience: insert and return the row
function database.insert(table_name, data)
local columns = {}
local values = {}
local params = {}
local i = 1
for k, v in pairs(data) do
table.insert(columns, k)
table.insert(values, '$' .. i)
table.insert(params, v)
i = i + 1
end
local sql = string.format(
'INSERT INTO %s (%s) VALUES (%s) RETURNING *',
table_name,
table.concat(columns, ', '),
table.concat(values, ', ')
)
return database.execute(sql, params)
end
-- Convenience: upsert (INSERT ON CONFLICT UPDATE)
function database.upsert(table_name, data, conflict_keys, update_keys)
local columns = {}
local values = {}
local params = {}
local i = 1
for k, v in pairs(data) do
table.insert(columns, k)
table.insert(values, '$' .. i)
table.insert(params, v)
i = i + 1
end
local updates = {}
for _, k in ipairs(update_keys) do
table.insert(updates, k .. ' = EXCLUDED.' .. k)
end
local sql = string.format(
'INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO UPDATE SET %s RETURNING *',
table_name,
table.concat(columns, ', '),
table.concat(values, ', '),
table.concat(conflict_keys, ', '),
table.concat(updates, ', ')
)
return database.execute(sql, params)
end
-- Get the raw pgmoon connection for advanced usage
function database.connection()
return database.acquire()
end
-- Get current pool stats
function database.pool_stats()
return {
available = #pool,
max_size = pool_size
}
end
function database.disconnect()
for _, pg in ipairs(pool) do
pcall(function() pg:disconnect() end)
end
pool = {}
+ pool_semaphore = nil
logger.info('Disconnected from PostgreSQL (pool drained)')
end
return database
diff --git a/src/core/redis.lua b/src/core/redis.lua
index 5d940d8..c75df8a 100644
--- a/src/core/redis.lua
+++ b/src/core/redis.lua
@@ -1,257 +1,313 @@
--[[
- mattata v2.0 - Redis Connection Module
+ mattata v2.1 - Redis Connection Pool Module
Redis is used as cache/session store only. PostgreSQL is the primary database.
- Includes automatic reconnection with backoff, SCAN replacement for KEYS, and pipeline support.
+ Implements connection pooling with copas semaphore guards,
+ automatic reconnection with backoff, SCAN replacement for KEYS, and pipeline support.
]]
local redis_mod = {}
local redis_lib = require('redis')
local config = require('src.core.config')
local logger = require('src.core.logger')
+local copas_sem = require('copas.semaphore')
-local client = nil
+local pool = {}
+local pool_size = 5
+local pool_semaphore = nil
local redis_cfg = nil
-local reconnect_attempts = 0
-local MAX_RECONNECT_ATTEMPTS = 10
-- Override hgetall to return key-value table instead of flat array
redis_lib.commands.hgetall = redis_lib.command('hgetall', {
response = function(response)
local result = {}
for i = 1, #response, 2 do
result[response[i]] = response[i + 1]
end
return result
end
})
-local function do_connect()
+-- Create a single Redis connection
+local function create_connection()
if not redis_cfg then
redis_cfg = config.redis_config()
end
+ local conn
local ok, err = pcall(function()
- client = redis_lib.connect({
+ conn = redis_lib.connect({
host = redis_cfg.host,
port = redis_cfg.port
})
end)
if not ok then
- return false, err
+ return nil, err
end
if redis_cfg.password and redis_cfg.password ~= '' then
- client:auth(redis_cfg.password)
+ conn:auth(redis_cfg.password)
end
if redis_cfg.db and redis_cfg.db ~= 0 then
- client:select(redis_cfg.db)
+ conn:select(redis_cfg.db)
end
- reconnect_attempts = 0
- return true
+ return conn
end
--- Automatic reconnection with exponential backoff
-local function ensure_connected()
- if client then
- -- Quick ping check
- local ok = pcall(function() client:ping() end)
- if ok then return true end
- logger.warn('Redis connection lost, attempting reconnect...')
- client = nil
+-- Acquire a connection from the pool
+local function acquire()
+ -- Take a semaphore permit (blocks coroutine if pool exhausted)
+ if pool_semaphore then
+ local ok, err = pool_semaphore:take(1, 10)
+ if not ok then
+ logger.error('Redis pool semaphore timeout: %s', tostring(err))
+ return nil, 'Redis pool exhausted'
+ end
end
- while reconnect_attempts < MAX_RECONNECT_ATTEMPTS do
- reconnect_attempts = reconnect_attempts + 1
- local backoff = math.min(2 ^ reconnect_attempts, 30)
- logger.info('Redis reconnect attempt %d/%d (backoff: %ds)', reconnect_attempts, MAX_RECONNECT_ATTEMPTS, backoff)
- local ok, err = do_connect()
+ -- Try pooled connections, discard dead ones
+ while #pool > 0 do
+ local conn = table.remove(pool)
+ local ok = pcall(function() conn:ping() end)
if ok then
- logger.info('Redis reconnected successfully')
- return true
+ return conn
end
- logger.warn('Redis reconnect failed: %s', tostring(err))
- local socket = require('socket')
- socket.sleep(backoff)
+ logger.warn('Discarding dead pooled Redis connection')
end
- logger.error('Redis reconnection failed after %d attempts', MAX_RECONNECT_ATTEMPTS)
- return false
+ -- Create fresh connection
+ local conn, err = create_connection()
+ if not conn then
+ logger.error('Failed to create Redis connection: %s', tostring(err))
+ if pool_semaphore then pool_semaphore:give(1) end
+ return nil, err
+ end
+ return conn
+end
+
+-- Release a connection back to the pool
+local function release(conn)
+ if not conn then return end
+ if #pool < pool_size then
+ table.insert(pool, conn)
+ else
+ pcall(function() conn:quit() end)
+ end
+ if pool_semaphore then pool_semaphore:give(1) end
+end
+
+-- Discard a connection without returning it to the pool
+local function discard(conn)
+ if conn then
+ pcall(function() conn:quit() end)
+ end
+ if pool_semaphore then pool_semaphore:give(1) end
end
-- Safe command wrapper with auto-reconnect
-local function safe_call(method, ...)
- if not ensure_connected() then
+-- method_name is a string like 'get', 'set', etc.
+local function safe_call(method_name, ...)
+ local conn, err = acquire()
+ if not conn then
return nil
end
- local ok, result = pcall(method, client, ...)
+ local ok, result = pcall(function(...)
+ return conn[method_name](conn, ...)
+ end, ...)
if not ok then
- -- Connection may have dropped mid-call
- logger.warn('Redis command failed: %s — retrying after reconnect', tostring(result))
- client = nil
- if ensure_connected() then
- ok, result = pcall(method, client, ...)
- if ok then return result end
+ -- Connection may have dropped mid-call — discard and retry once
+ logger.warn('Redis %s failed: %s — retrying after reconnect', method_name, tostring(result))
+ discard(conn)
+ conn, err = acquire()
+ if not conn then
+ logger.error('Redis reconnect failed: %s', tostring(err))
+ return nil
+ end
+ ok, result = pcall(function(...)
+ return conn[method_name](conn, ...)
+ end, ...)
+ if not ok then
+ logger.error('Redis %s failed after reconnect: %s', method_name, tostring(result))
+ discard(conn)
+ return nil
end
- logger.error('Redis command failed after reconnect: %s', tostring(result))
- return nil
end
+ release(conn)
return result
end
function redis_mod.connect()
redis_cfg = config.redis_config()
- local ok, err = do_connect()
- if not ok then
+ pool_size = config.get_number('REDIS_POOL_SIZE', 5)
+
+ -- Create initial connection to validate credentials
+ local conn, err = create_connection()
+ if not conn then
logger.error('Failed to connect to Redis: %s', tostring(err))
return false, err
end
- logger.info('Connected to Redis at %s:%d (db %d)', redis_cfg.host, redis_cfg.port, redis_cfg.db or 0)
+ table.insert(pool, conn)
+
+ -- Create semaphore to guard concurrent pool access
+ -- max = pool_size, start = pool_size (all permits available), timeout = 10s
+ pool_semaphore = copas_sem.new(pool_size, pool_size, 10)
+
+ logger.info('Connected to Redis at %s:%d (db %d, pool size: %d)', redis_cfg.host, redis_cfg.port, redis_cfg.db or 0, pool_size)
return true
end
--- Get the raw redis client
+-- Get the raw redis client (deprecated — prefer using proxy functions)
function redis_mod.client()
- ensure_connected()
- return client
+ logger.warn('redis.client() is deprecated — use redis proxy functions instead')
+ local conn = acquire()
+ return conn
end
-- Proxy common operations with auto-reconnect
function redis_mod.get(key)
- return safe_call(client.get, key)
+ return safe_call('get', key)
end
function redis_mod.set(key, value)
- return safe_call(client.set, key, value)
+ return safe_call('set', key, value)
end
function redis_mod.setex(key, ttl, value)
- return safe_call(client.setex, key, ttl, value)
+ return safe_call('setex', key, ttl, value)
end
function redis_mod.setnx(key, value)
- return safe_call(client.setnx, key, value)
+ return safe_call('setnx', key, value)
end
function redis_mod.del(key)
- return safe_call(client.del, key)
+ return safe_call('del', key)
end
function redis_mod.exists(key)
- return safe_call(client.exists, key)
+ return safe_call('exists', key)
end
function redis_mod.expire(key, ttl)
- return safe_call(client.expire, key, ttl)
+ return safe_call('expire', key, ttl)
end
function redis_mod.incr(key)
- return safe_call(client.incr, key)
+ return safe_call('incr', key)
end
function redis_mod.incrby(key, amount)
- return safe_call(client.incrby, key, amount)
+ return safe_call('incrby', key, amount)
end
function redis_mod.hget(key, field)
- return safe_call(client.hget, key, field)
+ return safe_call('hget', key, field)
end
function redis_mod.hset(key, field, value)
- return safe_call(client.hset, key, field, value)
+ return safe_call('hset', key, field, value)
end
function redis_mod.hdel(key, field)
- return safe_call(client.hdel, key, field)
+ return safe_call('hdel', key, field)
end
function redis_mod.hgetall(key)
- return safe_call(client.hgetall, key)
+ return safe_call('hgetall', key)
end
function redis_mod.hexists(key, field)
- return safe_call(client.hexists, key, field)
+ return safe_call('hexists', key, field)
end
function redis_mod.hincrby(key, field, increment)
- return safe_call(client.hincrby, key, field, increment)
+ return safe_call('hincrby', key, field, increment)
end
function redis_mod.sadd(key, value)
- return safe_call(client.sadd, key, value)
+ return safe_call('sadd', key, value)
end
function redis_mod.srem(key, value)
- return safe_call(client.srem, key, value)
+ return safe_call('srem', key, value)
end
function redis_mod.sismember(key, value)
- return safe_call(client.sismember, key, value)
+ return safe_call('sismember', key, value)
end
function redis_mod.smembers(key)
- return safe_call(client.smembers, key)
+ return safe_call('smembers', key)
end
-- List operations (used by AI plugin)
function redis_mod.rpush(key, value)
- return safe_call(client.rpush, key, value)
+ return safe_call('rpush', key, value)
end
function redis_mod.lrange(key, start, stop)
- return safe_call(client.lrange, key, start, stop)
+ return safe_call('lrange', key, start, stop)
end
function redis_mod.ltrim(key, start, stop)
- return safe_call(client.ltrim, key, start, stop)
+ return safe_call('ltrim', key, start, stop)
end
-- SCAN-based iteration — replaces all KEYS usage
-- Returns all keys matching pattern without blocking
function redis_mod.scan(pattern)
- if not ensure_connected() then
+ local conn = acquire()
+ if not conn then
return {}
end
local results = {}
local cursor = '0'
repeat
local ok, reply = pcall(function()
- return client:scan(cursor, { match = pattern, count = 100 })
+ return conn:scan(cursor, { match = pattern, count = 100 })
end)
- if not ok or not reply then break end
+ if not ok or not reply then
+ discard(conn)
+ return results
+ end
cursor = reply[1]
for _, key in ipairs(reply[2]) do
table.insert(results, key)
end
until cursor == '0'
+ release(conn)
return results
end
-- DEPRECATED: kept for compatibility but uses SCAN internally
function redis_mod.keys(pattern)
logger.warn('redis.keys() called — prefer redis.scan() to avoid blocking')
return redis_mod.scan(pattern)
end
-- Pipeline support: batch multiple commands and execute together
function redis_mod.pipeline(fn)
- if not ensure_connected() then
+ local conn = acquire()
+ if not conn then
return nil
end
- local pipeline = client:pipeline()
+ local pipeline = conn:pipeline()
fn(pipeline)
local ok, results = pcall(function()
return pipeline:execute()
end)
if not ok then
logger.error('Redis pipeline failed: %s', tostring(results))
+ discard(conn)
return nil
end
+ release(conn)
return results
end
function redis_mod.disconnect()
- if client then
- pcall(function() client:quit() end)
- client = nil
- logger.info('Disconnected from Redis')
+ for _, conn in ipairs(pool) do
+ pcall(function() conn:quit() end)
end
+ pool = {}
+ pool_semaphore = nil
+ logger.info('Disconnected from Redis (pool drained)')
end
return redis_mod
diff --git a/src/core/router.lua b/src/core/router.lua
index dba32b0..321828a 100644
--- a/src/core/router.lua
+++ b/src/core/router.lua
@@ -1,430 +1,410 @@
--[[
- mattata v2.0 - Event Router
+ mattata v2.1 - Event Router
Dispatches Telegram updates through middleware pipeline to plugins.
- Handles messages, callback queries, inline queries, and other events.
+ Uses copas coroutines via telegram-bot-lua's async system for concurrent
+ update processing — each update runs in its own coroutine.
]]
local router = {}
local json = require('dkjson')
-local socket = require('socket')
+local copas = require('copas')
local config = require('src.core.config')
local logger = require('src.core.logger')
local middleware_pipeline = require('src.core.middleware')
local session = require('src.core.session')
local permissions = require('src.core.permissions')
local i18n = require('src.core.i18n')
local tools
local api, loader, ctx_base
-- Import middleware modules
local mw_blocklist = require('src.middleware.blocklist')
local mw_rate_limit = require('src.middleware.rate_limit')
local mw_user_tracker = require('src.middleware.user_tracker')
local mw_language = require('src.middleware.language')
local mw_federation = require('src.middleware.federation')
local mw_captcha = require('src.middleware.captcha')
local mw_stats = require('src.middleware.stats')
function router.init(api_ref, tools_ref, loader_ref, ctx_base_ref)
api = api_ref
tools = tools_ref
loader = loader_ref
ctx_base = ctx_base_ref
-- Register middleware in order
middleware_pipeline.use(mw_blocklist)
middleware_pipeline.use(mw_rate_limit)
middleware_pipeline.use(mw_federation)
middleware_pipeline.use(mw_captcha)
middleware_pipeline.use(mw_user_tracker)
middleware_pipeline.use(mw_language)
middleware_pipeline.use(mw_stats)
end
-- Build a fresh context for each update
-- Admin check is lazy — only resolved when ctx:check_admin() is called
local function build_ctx(message)
local ctx = {}
for k, v in pairs(ctx_base) do
ctx[k] = v
end
ctx.is_group = message.chat and message.chat.type ~= 'private'
ctx.is_supergroup = message.chat and message.chat.type == 'supergroup'
ctx.is_private = message.chat and message.chat.type == 'private'
ctx.is_global_admin = message.from and permissions.is_global_admin(message.from.id) or false
-- Lazy admin check: only makes API call when first accessed
-- Caches result for the lifetime of this context
local admin_resolved = false
local admin_value = false
ctx.is_admin = false -- default for non-admin reads
function ctx:check_admin()
if admin_resolved then
return admin_value
end
admin_resolved = true
if ctx.is_global_admin then
admin_value = true
elseif ctx.is_group and message.from then
admin_value = permissions.is_group_admin(api, message.chat.id, message.from.id)
end
ctx.is_admin = admin_value
return admin_value
end
-- For backward compat: admin plugins that check ctx.is_admin will still
-- need to call ctx:check_admin() first. The router does this for admin_only plugins.
ctx.is_mod = false
return ctx
end
-- Sort/normalise a message object (ported from v1 mattata.sort_message)
local function sort_message(message)
message.text = message.text or message.caption or ''
-- Normalise /command_arg to /command arg
message.text = message.text:gsub('^(/[%a]+)_', '%1 ')
-- Deep-link support
if message.text:match('^[/!#]start .-$') then
message.text = '/' .. message.text:match('^[/!#]start (.-)$')
end
-- Shorthand reply alias
if message.reply_to_message then
message.reply = message.reply_to_message
message.reply_to_message = nil
end
-- Normalise language code
if message.from and message.from.language_code then
local lc = message.from.language_code:lower():gsub('%-', '_')
if #lc == 2 and lc ~= 'en' then
lc = lc .. '_' .. lc
elseif #lc == 2 or lc == 'root' then
lc = 'en_us'
end
message.from.language_code = lc
end
-- Detect media
message.is_media = message.photo or message.video or message.audio or message.voice
or message.document or message.sticker or message.animation or message.video_note or false
-- Detect service messages
message.is_service_message = (message.new_chat_members or message.left_chat_member
or message.new_chat_title or message.new_chat_photo or message.pinned_message
or message.group_chat_created or message.supergroup_chat_created) and true or false
-- Entity-based text mentions -> ID substitution
if message.entities then
for _, entity in ipairs(message.entities) do
if entity.type == 'text_mention' and entity.user then
local name = message.text:sub(entity.offset + 1, entity.offset + entity.length)
message.text = message.text:gsub(name, tostring(entity.user.id), 1)
end
end
end
-- Process caption entities as entities
if message.caption_entities then
message.entities = message.caption_entities
message.caption_entities = nil
end
-- Sort reply recursively
if message.reply then
message.reply = sort_message(message.reply)
end
return message
end
-- Extract command from message text
local function extract_command(text, bot_username)
if not text then return nil, nil end
local cmd, args = text:match('^[/!#]([%w_]+)@?' .. (bot_username or '') .. '%s*(.*)')
if not cmd then
cmd, args = text:match('^[/!#]([%w_]+)%s*(.*)')
end
if cmd then
cmd = cmd:lower()
args = args ~= '' and args or nil
end
return cmd, args
end
-- Resolve aliases for a chat (with Redis caching)
local function resolve_alias(message, redis_mod)
if not message.text:match('^[/!#][%w_]+') then return message end
if not message.chat or message.chat.type == 'private' then return message end
local command, rest = message.text:lower():match('^[/!#]([%w_]+)(.*)')
if not command then return message end
-- Cache alias lookups with TTL instead of hgetall on every message
local cache_key = 'cache:aliases:' .. message.chat.id
local cached_aliases = redis_mod.get(cache_key)
local aliases
if cached_aliases then
local ok, decoded = pcall(json.decode, cached_aliases)
if ok and decoded then
aliases = decoded
end
end
if not aliases then
aliases = redis_mod.hgetall('chat:' .. message.chat.id .. ':aliases')
if type(aliases) == 'table' then
pcall(function()
redis_mod.setex(cache_key, 300, json.encode(aliases))
end)
end
end
if type(aliases) == 'table' then
for alias, original in pairs(aliases) do
if command == alias then
message.text = '/' .. original .. (rest or '')
message.is_alias = true
break
end
end
end
return message
end
-- Process action state (multi-step commands)
-- Fixed: save message_id before nil'ing message.reply
local function process_action(message, ctx)
if message.text and message.chat and message.reply
and message.reply.from and message.reply.from.id == api.info.id then
local reply_message_id = message.reply.message_id
local action = session.get_action(message.chat.id, reply_message_id)
if action then
message.text = action .. ' ' .. message.text
message.reply = nil
session.del_action(message.chat.id, reply_message_id)
end
end
return message
end
-- Handle a message update
local function on_message(message)
-- Validate
if not message or not message.from then return end
if message.date and message.date < os.time() - 10 then return end
-- Sort/normalise
message = sort_message(message)
message = process_action(message, ctx_base)
message = resolve_alias(message, ctx_base.redis)
-- Build context and run middleware
local ctx = build_ctx(message)
local should_continue
ctx, should_continue = middleware_pipeline.run(ctx, message)
if not should_continue then return end
-- Dispatch command to matching plugin
local cmd, args = extract_command(message.text, api.info.username)
- local command_handled = false
if cmd then
local plugin = loader.get_by_command(cmd)
if plugin and plugin.on_message then
if not session.is_plugin_disabled(message.chat.id, plugin.name) or loader.is_permanent(plugin.name) then
-- Check permission requirements
if plugin.global_admin_only and not ctx.is_global_admin then
return
end
-- Resolve admin status only for admin_only plugins (lazy check)
if plugin.admin_only then
ctx:check_admin()
if not ctx.is_admin and not ctx.is_global_admin then
return api.send_message(message.chat.id, ctx.lang and ctx.lang.errors and ctx.lang.errors.admin or 'You need to be an admin to use this command.')
end
end
if plugin.group_only and ctx.is_private then
return api.send_message(message.chat.id, ctx.lang and ctx.lang.errors and ctx.lang.errors.supergroup or 'This command can only be used in groups.')
end
message.command = cmd
message.args = args
local ok, err = pcall(plugin.on_message, api, message, ctx)
if not ok then
logger.error('Plugin %s.on_message error: %s', plugin.name, tostring(err))
if config.log_chat() then
api.send_message(config.log_chat(), string.format(
'<pre>[%s] %s error:\n%s\nFrom: %s\nText: %s</pre>',
os.date('%X'), plugin.name,
tools.escape_html(tostring(err)),
message.from.id,
tools.escape_html(message.text or '')
), 'html')
end
end
- command_handled = true
end
end
end
-- Run passive handlers (on_new_message) for all non-disabled plugins
for _, plugin in ipairs(loader.get_plugins()) do
if plugin.on_new_message and not session.is_plugin_disabled(message.chat.id, plugin.name) then
local ok, err = pcall(plugin.on_new_message, api, message, ctx)
if not ok then
logger.error('Plugin %s.on_new_message error: %s', plugin.name, tostring(err))
end
end
-- Handle member join events
if message.new_chat_members and plugin.on_member_join then
local ok, err = pcall(plugin.on_member_join, api, message, ctx)
if not ok then
logger.error('Plugin %s.on_member_join error: %s', plugin.name, tostring(err))
end
end
end
end
-- Handle callback query (routed through middleware for blocklist + rate limit)
local function on_callback_query(callback_query)
if not callback_query or not callback_query.from then return end
if not callback_query.data then return end
local message = callback_query.message or {
chat = {},
message_id = callback_query.inline_message_id,
from = callback_query.from
}
-- Parse plugin_name:data format
local plugin_name, cb_data = callback_query.data:match('^(.-):(.*)$')
if not plugin_name then return end
local plugin = loader.get_by_name(plugin_name)
if not plugin or not plugin.on_callback_query then return end
callback_query.data = cb_data
-- Build context and run basic middleware (blocklist + rate limit)
local ctx = build_ctx(message)
-- Check blocklist for callback user
if session.is_globally_blocklisted(callback_query.from.id) then
return
end
-- Load language for callback user
local lang_code = session.get_setting(callback_query.from.id, 'language') or 'en_gb'
ctx.lang = i18n.get(lang_code)
local ok, err = pcall(plugin.on_callback_query, api, callback_query, message, ctx)
if not ok then
logger.error('Plugin %s.on_callback_query error: %s', plugin_name, tostring(err))
end
end
-- Handle inline query
local function on_inline_query(inline_query)
if not inline_query or not inline_query.from then return end
if session.is_globally_blocklisted(inline_query.from.id) then return end
local ctx = build_ctx({ from = inline_query.from, chat = { type = 'private' } })
local lang_code = session.get_setting(inline_query.from.id, 'language') or 'en_gb'
ctx.lang = i18n.get(lang_code)
for _, plugin in ipairs(loader.get_plugins()) do
if plugin.on_inline_query then
local ok, err = pcall(plugin.on_inline_query, api, inline_query, ctx)
if not ok then
logger.error('Plugin %s.on_inline_query error: %s', plugin.name, tostring(err))
end
end
end
end
--- Run cron jobs asynchronously in coroutines
-local function run_cron_async()
- for _, plugin in ipairs(loader.get_plugins()) do
- if plugin.cron then
- local co = coroutine.create(function()
- local ok, err = pcall(plugin.cron, api, ctx_base)
- if not ok then
- logger.error('Plugin %s cron error: %s', plugin.name, tostring(err))
- end
- end)
- coroutine.resume(co)
- end
- end
-end
-
--- Main polling loop
+-- Concurrent polling loop using telegram-bot-lua's async system
function router.run()
- local last_update = 0
- local last_cron = os.date('%M')
- local last_stats_flush = 0
local polling = config.polling()
- while true do
- local success = api.get_updates(
- polling.timeout,
- last_update + 1,
- polling.limit,
- json.encode({
- 'message', 'edited_message', 'callback_query', 'inline_query',
- 'chat_join_request', 'chat_member', 'my_chat_member',
- 'message_reaction'
- })
- )
-
- if success and success.result then
- for _, update in ipairs(success.result) do
- last_update = update.update_id
- local start_time = socket.gettime()
-
- if update.message or update.edited_message then
- local msg = update.message or update.edited_message
- if update.edited_message then
- msg.is_edited = true
- end
- local ok, err = pcall(on_message, msg)
- if not ok then
- logger.error('on_message error: %s', tostring(err))
- end
- elseif update.callback_query then
- local ok, err = pcall(on_callback_query, update.callback_query)
- if not ok then
- logger.error('on_callback_query error: %s', tostring(err))
- end
- elseif update.inline_query then
- local ok, err = pcall(on_inline_query, update.inline_query)
- if not ok then
- logger.error('on_inline_query error: %s', tostring(err))
- end
- end
+ -- Register telegram-bot-lua handler callbacks
+ -- api.process_update() dispatches to these inside per-update copas coroutines
+ api.on_message = function(msg)
+ local ok, err = pcall(on_message, msg)
+ if not ok then logger.error('on_message error: %s', tostring(err)) end
+ end
- if config.debug() then
- logger.debug('Update #%d processed in %.3fs', update.update_id, socket.gettime() - start_time)
- end
- end
- else
- logger.error('Failed to retrieve updates from Telegram API')
- end
+ api.on_edited_message = function(msg)
+ msg.is_edited = true
+ local ok, err = pcall(on_message, msg)
+ if not ok then logger.error('on_edited_message error: %s', tostring(err)) end
+ end
- -- Minutely cron jobs (async via coroutines)
- if last_cron ~= os.date('%M') then
- last_cron = os.date('%M')
- run_cron_async()
- end
+ api.on_callback_query = function(cb)
+ local ok, err = pcall(on_callback_query, cb)
+ if not ok then logger.error('on_callback_query error: %s', tostring(err)) end
+ end
- -- Flush stats counters to PostgreSQL every 5 minutes
- local now = os.time()
- if now - last_stats_flush >= 300 then
- last_stats_flush = now
- local co = coroutine.create(function()
- local ok, err = pcall(mw_stats.flush, ctx_base.db, ctx_base.redis)
- if not ok then
- logger.error('Stats flush error: %s', tostring(err))
+ api.on_inline_query = function(iq)
+ local ok, err = pcall(on_inline_query, iq)
+ if not ok then logger.error('on_inline_query error: %s', tostring(err)) end
+ end
+
+ -- Cron: copas background thread, runs every 60s
+ copas.addthread(function()
+ while true do
+ copas.pause(60)
+ for _, plugin in ipairs(loader.get_plugins()) do
+ if plugin.cron then
+ copas.addthread(function()
+ local ok, err = pcall(plugin.cron, api, ctx_base)
+ if not ok then
+ logger.error('Plugin %s cron error: %s', plugin.name, tostring(err))
+ end
+ end)
end
- end)
- coroutine.resume(co)
+ end
end
- end
+ end)
+
+ -- Stats flush: copas background thread, runs every 300s
+ copas.addthread(function()
+ while true do
+ copas.pause(300)
+ local ok, err = pcall(mw_stats.flush, ctx_base.db, ctx_base.redis)
+ if not ok then logger.error('Stats flush error: %s', tostring(err)) end
+ end
+ end)
+
+ -- Start concurrent polling loop
+ -- api.run() -> api.async.run() which:
+ -- 1. Swaps api.request to copas-based api.async.request
+ -- 2. Spawns polling coroutine calling get_updates in a loop
+ -- 3. For each update, spawns NEW coroutine -> api.process_update -> handlers above
+ -- 4. Calls copas.loop()
+ api.run({
+ timeout = polling.timeout,
+ limit = polling.limit,
+ allowed_updates = {
+ 'message', 'edited_message', 'callback_query', 'inline_query',
+ 'chat_join_request', 'chat_member', 'my_chat_member',
+ 'message_reaction'
+ }
+ })
end
return router

File Metadata

Mime Type
text/x-diff
Expires
Sun, May 17, 9:13 AM (1 d, 12 h)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
63003
Default Alt Text
(77 KB)

Event Timeline