Page Menu
Home
Phabricator (Chris)
Search
Configure Global Search
Log In
Files
F119167
No One
Temporary
Actions
View File
Edit File
Delete File
View Transforms
Subscribe
Flag For Later
Award Token
Authored By
Unknown
Size
77 KB
Referenced Files
None
Subscribers
None
View Options
diff --git a/main.lua b/main.lua
index e504c0e..703464d 100644
--- a/main.lua
+++ b/main.lua
@@ -1,94 +1,94 @@
--[[
_ _ _
_ __ ___ __ _| |_| |_ __ _| |_ __ _
| '_ ` _ \ / _` | __| __/ _` | __/ _` |
| | | | | | (_| | |_| || (_| | || (_| |
|_| |_| |_|\__,_|\__|\__\__,_|\__\__,_|
- v2.0
+ v2.1
Copyright 2020-2026 Matthew Hesketh <matthew@matthewhesketh.com>
See LICENSE for details
]]
local config = require('src.core.config')
local logger = require('src.core.logger')
local database = require('src.core.database')
local redis = require('src.core.redis')
local session = require('src.core.session')
local i18n = require('src.core.i18n')
local loader = require('src.core.loader')
local router = require('src.core.router')
local migrations = require('src.db.init')
-- Boot sequence: each numbered step relies on the ones before it.

-- 1. Configuration and logging first, so every later step can log.
config.load('.env')
logger.init()
logger.info('mattata v%s starting...', config.VERSION)

-- 2. Refuse to start without a bot token.
assert(config.bot_token(), 'BOT_TOKEN is required. Set it in .env or as an environment variable.')

-- 3. Configure telegram-bot-lua with the token; `api.info` describes the bot account.
local api = require('telegram-bot-lua').configure(config.bot_token())
local tools = require('telegram-bot-lua.tools')
logger.info('Bot: @%s (%s) [%d]', api.info.username, api.info.first_name, api.info.id)

-- Log a fatal start-up error (fmt receives the stringified error) and exit with status 1.
local function die(fmt, err)
  logger.error(fmt, tostring(err))
  os.exit(1)
end

-- 4. PostgreSQL is mandatory ...
local db_ok, db_err = database.connect()
if not db_ok then
  die('Cannot start without PostgreSQL: %s', db_err)
end
-- 5. ... and its schema is migrated as soon as it is reachable.
migrations.run(database)

-- 6. Redis is mandatory too; the session module is initialised with it.
local redis_ok, redis_err = redis.connect()
if not redis_ok then
  die('Cannot start without Redis: %s', redis_err)
end
session.init(redis)

-- 7. Languages, then 8. plugins.
i18n.init()
loader.init(api, database, redis)

-- 9. Shared dependencies handed to the router's context factory.
local ctx_base = {
  api = api,
  tools = tools,
  db = database,
  redis = redis,
  session = session,
  config = config,
  i18n = i18n,
  permissions = require('src.core.permissions'),
  logger = logger,
}
router.init(api, tools, loader, ctx_base)

-- 10. Announce the start-up to the log chat (when configured), then to every bot admin.
local info_msg = string.format(
  '<pre>mattata v%s connected!\n\n Username: @%s\n Name: %s\n ID: %d\n Plugins: %d</pre>',
  config.VERSION,
  tools.escape_html(api.info.username),
  tools.escape_html(api.info.first_name),
  api.info.id,
  loader.count()
)
local log_chat = config.log_chat()
if log_chat then
  api.send_message(log_chat, info_msg, 'html')
end
for _, admin_id in ipairs(config.bot_admins()) do
  api.send_message(admin_id, info_msg, 'html')
end

-- 11. Hand control to the router's main loop.
logger.info('Starting main loop...')
router.run()
diff --git a/spec/core/concurrency_spec.lua b/spec/core/concurrency_spec.lua
new file mode 100644
index 0000000..f743b53
--- /dev/null
+++ b/spec/core/concurrency_spec.lua
@@ -0,0 +1,168 @@
+--[[
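+ Tests for v2.1 concurrency-safe pool mechanics.
+ The database and Redis mocks never create a semaphore, so these specs
+ cover the nil-semaphore fallback path (pool operations and the
+ method-name-string safe_call proxies) plus the mock API's async and
+ handler stubs; real semaphore contention is not exercised here.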
+]]
+
describe('concurrency', function()

  describe('database pool safety', function()
    local mock_db = require('spec.helpers.mock_db')

    it('should work without semaphore (test environment fallback)', function()
      -- mock_db never calls connect(), so no semaphore exists and every
      -- operation must behave exactly as it did in v2.0.
      local db = mock_db.new()
      assert.are.same({}, db.query('SELECT 1'))
      db.set_next_result({ { id = 1 } })
      assert.are.same({ { id = 1 } }, db.query('SELECT * FROM users'))
    end)

    it('should execute queries and return results without semaphore', function()
      local db = mock_db.new()
      db.set_next_result({ { count = 42 } })
      assert.are.same({ { count = 42 } }, db.execute('SELECT COUNT(*) FROM messages', {}))
    end)

    it('should handle transactions without semaphore', function()
      local db = mock_db.new()
      local ran = false
      db.transaction(function(query)
        ran = true
        query('SELECT 1')
      end)
      assert.is_true(ran)
    end)

    it('pool_stats should report available and max_size', function()
      local stats = mock_db.new().pool_stats()
      assert.is_number(stats.available)
      assert.is_number(stats.max_size)
    end)
  end)

  describe('redis pool safety', function()
    local mock_redis = require('spec.helpers.mock_redis')

    it('should work without semaphore (test environment fallback)', function()
      -- mock_redis never calls connect() with real config, so no semaphore
      local r = mock_redis.new()
      r.set('key', 'value')
      assert.are.equal('value', r.get('key'))
    end)

    it('safe_call with method name strings should work via proxy functions', function()
      -- The real module routes every proxy through safe_call('method_name', ...);
      -- the mock simulates that routing directly.
      local r = mock_redis.new()
      r.set('test_key', 'test_value')
      assert.are.equal('test_value', r.get('test_key'))
      for _, command in ipairs({ 'set', 'get' }) do
        assert.is_true(r.has_command(command))
      end
    end)

    it('hash operations should work through proxy', function()
      local r = mock_redis.new()
      r.hset('hash:1', 'field1', 'value1')
      assert.are.equal('value1', r.hget('hash:1', 'field1'))
      assert.are.equal('value1', r.hgetall('hash:1')['field1'])
    end)

    it('set operations should work through proxy', function()
      local r = mock_redis.new()
      r.sadd('set:1', 'member1')
      assert.are.equal(1, r.sismember('set:1', 'member1'))
      assert.are.equal(0, r.sismember('set:1', 'nonexistent'))
    end)

    it('list operations should work through proxy', function()
      local r = mock_redis.new()
      for _, item in ipairs({ 'a', 'b' }) do
        r.rpush('list:1', item)
      end
      local items = r.lrange('list:1', 0, -1)
      assert.are.equal(2, #items)
      assert.are.equal('a', items[1])
    end)

    it('scan should work through pool', function()
      local r = mock_redis.new()
      r.set('prefix:1', 'a')
      r.set('prefix:2', 'b')
      r.set('other:1', 'c')
      assert.are.equal(2, #r.scan('prefix:*'))
    end)

    it('client() should still return a usable object', function()
      assert.is_not_nil(mock_redis.new().client())
    end)
  end)

  describe('mock_api async stubs', function()
    local mock_api = require('spec.helpers.mock_api')

    it('should have handler stubs', function()
      local api = mock_api.new()
      for _, handler in ipairs({ 'on_message', 'on_edited_message', 'on_callback_query', 'on_inline_query' }) do
        assert.is_function(api[handler])
      end
    end)

    it('should have async module stubs', function()
      local api = mock_api.new()
      assert.is_table(api.async)
      for _, fn_name in ipairs({ 'run', 'stop', 'all', 'spawn', 'sleep', 'is_running' }) do
        assert.is_function(api.async[fn_name])
      end
    end)

    it('should have api.run stub', function()
      local api = mock_api.new()
      assert.is_function(api.run)
      -- Calling the stub must not raise
      assert.has_no.errors(function()
        api.run({ timeout = 60, limit = 100 })
      end)
    end)

    it('should have process_update stub', function()
      local api = mock_api.new()
      assert.is_function(api.process_update)
      assert.has_no.errors(function()
        api.process_update({ update_id = 1, message = {} })
      end)
    end)

    it('async.is_running should return false in test environment', function()
      assert.is_false(mock_api.new().async.is_running())
    end)

    it('async.spawn should execute the function', function()
      local api = mock_api.new()
      local executed = false
      api.async.spawn(function() executed = true end)
      assert.is_true(executed)
    end)

    it('handler stubs should be overwritable', function()
      local api = mock_api.new()
      local received
      api.on_message = function(msg) received = msg end
      api.on_message({ text = 'hello' })
      assert.are.same({ text = 'hello' }, received)
    end)
  end)
end)
diff --git a/spec/core/config_spec.lua b/spec/core/config_spec.lua
index a1c8ddc..9f30156 100644
--- a/spec/core/config_spec.lua
+++ b/spec/core/config_spec.lua
@@ -1,364 +1,364 @@
--[[
Tests for src/core/config.lua
Config module: loading .env, value lookup with defaults, typed access (number, boolean, list) and convenience accessors.
]]
describe('core.config', function()
  local config
  local tmpfile

  -- Write `content` to a fresh temporary file (removed in after_each) and load it.
  local function load_env(content)
    tmpfile = os.tmpname()
    local handle = io.open(tmpfile, 'w')
    handle:write(content)
    handle:close()
    config.load(tmpfile)
  end

  -- Register a test that loads `content` and checks each { key, expected } pair in order.
  local function it_parses(name, content, expectations)
    it(name, function()
      load_env(content)
      for _, pair in ipairs(expectations) do
        assert.are.equal(pair[2], config.get(pair[1]))
      end
    end)
  end

  before_each(function()
    -- Drop the cached module so every test starts from an unloaded config
    package.loaded['src.core.config'] = nil
    config = require('src.core.config')
  end)

  after_each(function()
    if tmpfile then
      os.remove(tmpfile)
      tmpfile = nil
    end
  end)

  describe('load()', function()
    it_parses('should load a valid .env file', 'FOO=bar\nBAZ=qux\n', { { 'FOO', 'bar' }, { 'BAZ', 'qux' } })

    it('should not error on missing .env file', function()
      assert.has_no.errors(function()
        config.load('/tmp/nonexistent_env_file_' .. os.time())
      end)
    end)

    it_parses('should ignore empty lines', 'FOO=bar\n\n\nBAZ=qux\n', { { 'FOO', 'bar' }, { 'BAZ', 'qux' } })
    it_parses('should ignore comment lines', '# This is a comment\nFOO=bar\n# Another comment\n', { { 'FOO', 'bar' } })
    it_parses('should strip surrounding double quotes from values', 'FOO="hello world"\n', { { 'FOO', 'hello world' } })
    it_parses('should strip surrounding single quotes from values', "FOO='hello world'\n", { { 'FOO', 'hello world' } })
    it_parses('should strip inline comments from unquoted values', 'FOO=bar # this is a comment\n', { { 'FOO', 'bar' } })
    it_parses('should handle values with equals signs', 'FOO=bar=baz\n', { { 'FOO', 'bar=baz' } })
    it_parses('should trim whitespace around keys and values', ' FOO = bar \n', { { 'FOO', 'bar' } })
  end)

  describe('get()', function()
    it_parses('should return the value for a known key', 'MY_KEY=my_value\n', { { 'MY_KEY', 'my_value' } })

    it('should return default when key is missing', function()
      load_env('OTHER_KEY=other\n')
      assert.are.equal('fallback', config.get('NONEXISTENT', 'fallback'))
    end)

    it('should return nil when key is missing and no default', function()
      load_env('')
      assert.is_nil(config.get('NONEXISTENT'))
    end)

    it('should fall back to os.getenv for empty .env values', function()
      load_env('EMPTY_KEY=\n')
      -- Either the process environment's value or the default comes back
      assert.is_not_nil(config.get('EMPTY_KEY', 'default_val'))
    end)

    it('should auto-load .env if not explicitly loaded', function()
      -- get() without a prior load() must not error
      assert.has_no.errors(function()
        config.get('ANYTHING')
      end)
    end)
  end)

  describe('get_number()', function()
    local cases = {
      { name = 'should return a number for numeric values', env = 'PORT=8080\n', key = 'PORT', want = 8080 },
      { name = 'should return default for non-numeric values', env = 'PORT=abc\n', key = 'PORT', default = 3000, want = 3000 },
      { name = 'should return default when key is missing', env = '', key = 'DB_PORT', default = 5432, want = 5432 },
      { name = 'should return nil when key is missing and no default', env = '', key = 'MISSING' },
      { name = 'should handle float values', env = 'RATE=1.5\n', key = 'RATE', want = 1.5 },
      { name = 'should handle negative numbers', env = 'OFFSET=-10\n', key = 'OFFSET', want = -10 },
    }
    for _, case in ipairs(cases) do
      it(case.name, function()
        load_env(case.env)
        local actual = config.get_number(case.key, case.default)
        if case.want == nil then
          assert.is_nil(actual)
        else
          assert.are.equal(case.want, actual)
        end
      end)
    end
  end)

  describe('is_enabled()', function()
    -- { test name, .env content, key queried, expected flag }
    local cases = {
      { 'should return true for "true"', 'FLAG=true\n', 'FLAG', true },
      { 'should return true for "1"', 'FLAG=1\n', 'FLAG', true },
      { 'should return true for "yes"', 'FLAG=yes\n', 'FLAG', true },
      { 'should return true case-insensitively for "TRUE"', 'FLAG=TRUE\n', 'FLAG', true },
      { 'should return true case-insensitively for "Yes"', 'FLAG=Yes\n', 'FLAG', true },
      { 'should return false for "false"', 'FLAG=false\n', 'FLAG', false },
      { 'should return false for "0"', 'FLAG=0\n', 'FLAG', false },
      { 'should return false for "no"', 'FLAG=no\n', 'FLAG', false },
      { 'should return false for missing key', '', 'MISSING', false },
      { 'should return false for arbitrary string', 'FLAG=maybe\n', 'FLAG', false },
    }
    for _, case in ipairs(cases) do
      local name, env, key, expected = case[1], case[2], case[3], case[4]
      it(name, function()
        load_env(env)
        if expected then
          assert.is_true(config.is_enabled(key))
        else
          assert.is_false(config.is_enabled(key))
        end
      end)
    end
  end)

  describe('get_list()', function()
    -- Assert `list` has exactly the length of `expected` and matches it element by element.
    local function assert_items(expected, list)
      assert.are.equal(#expected, #list)
      for i, item in ipairs(expected) do
        assert.are.equal(item, list[i])
      end
    end

    it('should split comma-separated values into a table', function()
      load_env('ITEMS=a,b,c\n')
      assert_items({ 'a', 'b', 'c' }, config.get_list('ITEMS'))
    end)

    it('should trim whitespace around items', function()
      load_env('ITEMS= a , b , c \n')
      assert_items({ 'a', 'b', 'c' }, config.get_list('ITEMS'))
    end)

    it('should convert numeric items to numbers', function()
      load_env('IDS=100,200,300\n')
      assert_items({ 100, 200, 300 }, config.get_list('IDS'))
    end)

    it('should return empty table for missing key', function()
      load_env('')
      assert.are.same({}, config.get_list('MISSING'))
    end)

    it('should return empty table for empty value', function()
      load_env('ITEMS=\n')
      assert.are.same({}, config.get_list('ITEMS'))
    end)

    it('should handle single-item list', function()
      load_env('ITEMS=only\n')
      assert_items({ 'only' }, config.get_list('ITEMS'))
    end)

    it('should handle mixed numeric and string items', function()
      load_env('ITEMS=100,hello,300\n')
      local list = config.get_list('ITEMS')
      assert.are.equal(100, list[1])
      assert.are.equal('hello', list[2])
      assert.are.equal(300, list[3])
    end)
  end)

  describe('convenience accessors', function()
    it('should return bot_token', function()
      load_env('BOT_TOKEN=12345:ABCDEF\n')
      assert.are.equal('12345:ABCDEF', config.bot_token())
    end)

    it('should return bot_admins as a list', function()
      load_env('BOT_ADMINS=221714512,123456\n')
      local admins = config.bot_admins()
      assert.are.equal(2, #admins)
      assert.are.equal(221714512, admins[1])
    end)

    it('should return bot_name with default', function()
      load_env('')
      assert.are.equal('mattata', config.bot_name())
    end)

    it('should return database config with defaults', function()
      load_env('')
      local db = config.database()
      assert.are.equal('postgres', db.host)
      assert.are.equal(5432, db.port)
      assert.are.equal('mattata', db.database)
    end)

    it('should return redis config with defaults', function()
      load_env('')
      local rc = config.redis_config()
      assert.are.equal('redis', rc.host)
      assert.are.equal(6379, rc.port)
      assert.are.equal(0, rc.db)
    end)

    it('should return polling config with defaults', function()
      load_env('')
      local polling = config.polling()
      assert.are.equal(60, polling.timeout)
      assert.are.equal(100, polling.limit)
    end)

    it('should return webhook config', function()
      load_env('WEBHOOK_ENABLED=true\nWEBHOOK_URL=https://example.com\nWEBHOOK_PORT=8443\n')
      local wh = config.webhook()
      assert.is_true(wh.enabled)
      assert.are.equal('https://example.com', wh.url)
      assert.are.equal(8443, wh.port)
    end)

    it('should return debug status', function()
      load_env('DEBUG=true\n')
      assert.is_true(config.debug())
    end)

    it('should return ai config with defaults', function()
      load_env('')
      local ai = config.ai()
      assert.is_false(ai.enabled)
      assert.are.equal('gpt-4o', ai.openai_model)
    end)
  end)

  describe('VERSION', function()
    it('should be 2.1', function()
      assert.are.equal('2.1', config.VERSION)
    end)
  end)
end)
diff --git a/spec/helpers/mock_api.lua b/spec/helpers/mock_api.lua
index bac61ec..b7552a0 100644
--- a/spec/helpers/mock_api.lua
+++ b/spec/helpers/mock_api.lua
@@ -1,216 +1,243 @@
--[[
- mattata v2.0 - Mock Telegram Bot API
+ mattata v2.1 - Mock Telegram Bot API
Records all calls and returns configurable responses for testing.
+ Includes async/handler stubs for copas-based concurrency support.
]]
local mock_api = {}

--- Create a fresh mock of the telegram-bot-lua `api` object.
-- Each stubbed API method appends `{ method = name, args = {...} }` to
-- `api.calls` and returns a canned success response. set_admin,
-- set_bot_admin and set_creator customise get_chat_member;
-- get_call/get_calls/count_calls inspect the recording and reset() clears it.
-- @treturn table the mock api
function mock_api.new()
  local api = {
    info = { id = 123456789, username = 'testbot', first_name = 'Test Bot' },
    calls = {},
  }
  -- Overrides installed by the set_* helpers; replaced wholesale by api.reset().
  local custom_handlers = {}

  -- Append one call record (api.calls is read at call time so reset() is honoured).
  local function record(method, ...)
    local log = api.calls
    log[#log + 1] = { method = method, args = { ... } }
  end

  -- Record the call and return the generic `{ ok = true, result = true }` reply.
  local function succeed(method, ...)
    record(method, ...)
    return { ok = true, result = true }
  end

  -- Default get_chat_member reply: an ordinary member.
  local function member_reply(user_id)
    return { ok = true, result = { status = 'member', user = { id = user_id } } }
  end

  -- Stack a get_chat_member override on top of any existing one. `matches`
  -- decides which (chat_id, user_id) pairs it answers and `build(uid)` returns
  -- the member record for them; other pairs fall through to the previous
  -- override, or to the default member reply.
  local function override_member(matches, build)
    local previous = custom_handlers.get_chat_member
    custom_handlers.get_chat_member = function(cid, uid)
      if matches(cid, uid) then
        return { ok = true, result = build(uid) }
      end
      if previous then
        return previous(cid, uid)
      end
      return member_reply(uid)
    end
  end

  -- Collect the recorded calls to `method`, in call order.
  local function calls_to(method)
    local matched = {}
    for _, call in ipairs(api.calls) do
      if call.method == method then
        matched[#matched + 1] = call
      end
    end
    return matched
  end

  -- Messages ------------------------------------------------------------------

  function api.send_message(chat_id, text, parse_mode, ...)
    record('send_message', chat_id, text, parse_mode, ...)
    -- The message id is the running call count after recording this call.
    return { ok = true, result = { message_id = #api.calls, chat = { id = chat_id } } }
  end

  function api.edit_message_text(chat_id, message_id, text, parse_mode, ...)
    record('edit_message_text', chat_id, message_id, text, parse_mode, ...)
    return { ok = true, result = { message_id = message_id } }
  end

  function api.edit_message_reply_markup(chat_id, message_id, inline_message_id, keyboard)
    record('edit_message_reply_markup', chat_id, message_id, inline_message_id, keyboard)
    return { ok = true, result = { message_id = message_id } }
  end

  function api.delete_message(chat_id, message_id)
    return succeed('delete_message', chat_id, message_id)
  end

  function api.pin_chat_message(chat_id, message_id, disable_notification)
    return succeed('pin_chat_message', chat_id, message_id, disable_notification)
  end

  function api.unpin_chat_message(chat_id, message_id)
    return succeed('unpin_chat_message', chat_id, message_id)
  end

  function api.answer_callback_query(callback_id, text)
    record('answer_callback_query', callback_id, text)
    return { ok = true }
  end

  function api.get_updates(timeout, offset, limit, allowed)
    record('get_updates', timeout, offset, limit, allowed)
    return { ok = true, result = {} }
  end

  -- Chats and members ---------------------------------------------------------

  function api.get_chat_member(chat_id, user_id)
    record('get_chat_member', chat_id, user_id)
    local handler = custom_handlers.get_chat_member
    if handler then
      return handler(chat_id, user_id)
    end
    return member_reply(user_id)
  end

  function api.ban_chat_member(chat_id, user_id, until_date)
    return succeed('ban_chat_member', chat_id, user_id, until_date)
  end

  function api.unban_chat_member(chat_id, user_id)
    return succeed('unban_chat_member', chat_id, user_id)
  end

  function api.restrict_chat_member(chat_id, user_id, perms_or_until, maybe_perms)
    return succeed('restrict_chat_member', chat_id, user_id, perms_or_until, maybe_perms)
  end

  function api.get_chat(chat_id)
    record('get_chat', chat_id)
    return { ok = true, result = { id = chat_id, first_name = 'Test User' } }
  end

  function api.leave_chat(chat_id)
    return succeed('leave_chat', chat_id)
  end

  -- Keyboard builders (chainable no-ops) --------------------------------------

  function api.inline_keyboard()
    local keyboard = {}
    function keyboard:row()
      return self
    end
    return keyboard
  end

  function api.row()
    local row = {}
    function row:callback_data_button()
      return self
    end
    function row:url_button()
      return self
    end
    return row
  end

  -- get_chat_member customisation ---------------------------------------------

  -- Make `user_id` a fully-privileged administrator of `chat_id`.
  function api.set_admin(chat_id, user_id)
    override_member(function(cid, uid)
      return cid == chat_id and uid == user_id
    end, function(uid)
      return {
        status = 'administrator',
        user = { id = uid },
        can_restrict_members = true,
        can_delete_messages = true,
        can_pin_messages = true,
        can_promote_members = true,
        can_invite_users = true,
      }
    end)
  end

  -- Make the bot an administrator of `chat_id`; absent perms default to false.
  function api.set_bot_admin(chat_id, perms)
    perms = perms or {}
    override_member(function(cid, uid)
      return cid == chat_id and uid == api.info.id
    end, function(uid)
      return {
        status = 'administrator',
        user = { id = uid },
        can_restrict_members = perms.can_restrict_members or false,
        can_delete_messages = perms.can_delete_messages or false,
        can_pin_messages = perms.can_pin_messages or false,
        can_promote_members = perms.can_promote_members or false,
        can_invite_users = perms.can_invite_users or false,
      }
    end)
  end

  -- Make `user_id` the creator of `chat_id`.
  function api.set_creator(chat_id, user_id)
    override_member(function(cid, uid)
      return cid == chat_id and uid == user_id
    end, function(uid)
      return { status = 'creator', user = { id = uid } }
    end)
  end

  -- Handler stubs (overwritten by router.run() in production)
  api.on_message = function() end
  api.on_edited_message = function() end
  api.on_callback_query = function() end
  api.on_inline_query = function() end

  -- Async stubs (telegram-bot-lua async system); spawn runs its function inline.
  api.async = {
    run = function() end,
    stop = function() end,
    all = function() return {} end,
    spawn = function(fn)
      if fn then
        fn()
      end
    end,
    sleep = function() end,
    is_running = function() return false end,
  }

  -- api.run stub — records only, so tests never enter copas.loop
  function api.run(opts)
    record('run', opts)
  end

  function api.process_update(update)
    record('process_update', update)
  end

  -- Inspection ----------------------------------------------------------------

  -- Forget every recorded call and every get_chat_member override.
  function api.reset()
    api.calls = {}
    custom_handlers = {}
  end

  function api.get_call(method)
    return calls_to(method)[1]
  end

  function api.get_calls(method)
    return calls_to(method)
  end

  function api.count_calls(method)
    return #calls_to(method)
  end

  return api
end
return mock_api
diff --git a/src/core/config.lua b/src/core/config.lua
index 44769b2..4366798 100644
--- a/src/core/config.lua
+++ b/src/core/config.lua
@@ -1,163 +1,163 @@
--[[
mattata v2.1 - Configuration Module
Reads configuration from .env file with os.getenv() fallback.
Provides typed access to all configuration values.
]]
local config = {}
-- Key/value pairs parsed from the most recently loaded .env file.
local env_values = {}
-- Whether config.load() has run; config.get() loads '.env' lazily otherwise.
local loaded = false

--- Turn the raw right-hand side of a `KEY=value` line into its final value.
-- A value wrapped in matching double or single quotes keeps everything
-- between the quotes verbatim (a `#` inside quotes is data, not a comment),
-- and a quoted value may be followed by an inline `# comment`.
-- An unquoted value loses an inline comment introduced by whitespace + `#`.
-- @tparam string raw  value text, already trimmed by the caller
-- @treturn string
local function normalize_value(raw)
  for _, quote in ipairs({ '"', "'" }) do
    -- Whole value quoted, or quoted value followed by a trailing comment.
    local inner = raw:match('^' .. quote .. '(.*)' .. quote .. '$')
      or raw:match('^' .. quote .. '(.-)' .. quote .. '%s*#')
    if inner then
      return inner
    end
  end
  return raw:match('^(.-)%s+#') or raw
end

--- Parse a .env file into a table of KEY -> value strings.
-- Lines are trimmed; blank lines and lines starting with '#' are skipped;
-- keys must consist of letters, digits and underscores. A file that cannot
-- be opened yields an empty table instead of an error.
-- @tparam string path
-- @treturn table
local function parse_env_file(path)
  local values = {}
  local file = io.open(path, 'r')
  if not file then
    return values
  end
  for line in file:lines() do
    line = line:match('^%s*(.-)%s*$') -- trim (also drops a trailing '\r')
    if line ~= '' and not line:match('^#') then
      local key, value = line:match('^([%w_]+)%s*=%s*(.*)$')
      if key then
        values[key] = normalize_value(value)
      end
    end
  end
  file:close()
  return values
end
--- Parse an env file and make its values the active configuration.
-- May be called repeatedly; each call replaces the previously loaded values.
-- @tparam[opt='.env'] string path
function config.load(path)
  env_values = parse_env_file(path or '.env')
  loaded = true
end
--- Look up a configuration value as a string.
-- Loads '.env' on first use if config.load() has not run. A value that is
-- missing or empty in the file falls back to the process environment; when
-- that is also missing or empty, `default` is returned.
-- @tparam string key
-- @param[opt] default
function config.get(key, default)
  if not loaded then
    config.load()
  end
  local from_file = env_values[key]
  if from_file ~= nil and from_file ~= '' then
    return from_file
  end
  local from_env = os.getenv(key)
  if from_env ~= nil and from_env ~= '' then
    return from_env
  end
  return default
end
--- Read a configuration value as a number.
-- @tparam string key
-- @param[opt] default
-- @return the parsed number, or `default` when the key is unset or not numeric
function config.get_number(key, default)
  local raw = config.get(key)
  local parsed = raw and tonumber(raw)
  if parsed == nil then
    return default
  end
  return parsed
end
-- Lower-cased spellings treated as "enabled".
local enabled_spellings = { ['true'] = true, ['1'] = true, ['yes'] = true }

--- Interpret a configuration value as a boolean flag.
-- @tparam string key
-- @treturn boolean true for "true", "1" or "yes" in any letter case; false otherwise, including when unset
function config.is_enabled(key)
  local value = config.get(key)
  return value ~= nil and enabled_spellings[value:lower()] == true
end
--- Split a comma-separated configuration value into a list.
-- Items are whitespace-trimmed, empty items are dropped, and items that
-- parse as numbers are stored as numbers.
-- @tparam string key
-- @treturn table a new (possibly empty) list
function config.get_list(key)
  local items = {}
  local raw = config.get(key)
  if raw == nil or raw == '' then
    return items
  end
  for piece in raw:gmatch('[^,]+') do
    local trimmed = piece:match('^%s*(.-)%s*$')
    if trimmed ~= '' then
      items[#items + 1] = tonumber(trimmed) or trimmed
    end
  end
  return items
end
-- Convenience accessors for common config groups.
-- Every accessor re-reads its keys through config.get / get_number /
-- get_list / is_enabled on each call, so results follow the latest load.

--- @return BOT_TOKEN, or nil when unset
config.bot_token = function()
  return config.get('BOT_TOKEN')
end

--- @treturn table BOT_ADMINS as a list (numeric entries become numbers)
config.bot_admins = function()
  return config.get_list('BOT_ADMINS')
end

--- @treturn string BOT_NAME, defaulting to 'mattata'
config.bot_name = function()
  return config.get('BOT_NAME', 'mattata')
end

--- @treturn table PostgreSQL connection settings
config.database = function()
  local get, get_number = config.get, config.get_number
  return {
    host = get('DATABASE_HOST', 'postgres'),
    port = get_number('DATABASE_PORT', 5432),
    database = get('DATABASE_NAME', 'mattata'),
    user = get('DATABASE_USER', 'mattata'),
    password = get('DATABASE_PASSWORD', 'changeme'),
  }
end

--- @treturn table Redis connection settings (password may be nil)
config.redis_config = function()
  local get, get_number = config.get, config.get_number
  return {
    host = get('REDIS_HOST', 'redis'),
    port = get_number('REDIS_PORT', 6379),
    password = get('REDIS_PASSWORD'),
    db = get_number('REDIS_DB', 0),
  }
end

--- @treturn table long-polling parameters
config.polling = function()
  return {
    timeout = config.get_number('POLLING_TIMEOUT', 60),
    limit = config.get_number('POLLING_LIMIT', 100),
  }
end

--- @treturn table webhook settings
config.webhook = function()
  return {
    enabled = config.is_enabled('WEBHOOK_ENABLED'),
    url = config.get('WEBHOOK_URL'),
    port = config.get_number('WEBHOOK_PORT', 8443),
    secret = config.get('WEBHOOK_SECRET'),
  }
end

--- @treturn table AI provider settings
config.ai = function()
  local get = config.get
  return {
    enabled = config.is_enabled('AI_ENABLED'),
    openai_key = get('OPENAI_API_KEY'),
    openai_model = get('OPENAI_MODEL', 'gpt-4o'),
    anthropic_key = get('ANTHROPIC_API_KEY'),
    anthropic_model = get('ANTHROPIC_MODEL', 'claude-sonnet-4-5-20250929'),
  }
end

--- @treturn boolean whether DEBUG is enabled
config.debug = function()
  return config.is_enabled('DEBUG')
end

--- @return LOG_CHAT as a number, or nil when unset or not numeric
config.log_chat = function()
  return config.get_number('LOG_CHAT')
end

-- Release version string, reported at start-up and in the admin notification.
config.VERSION = '2.1'
return config
diff --git a/src/core/database.lua b/src/core/database.lua
index d832951..fe99e4a 100644
--- a/src/core/database.lua
+++ b/src/core/database.lua
@@ -1,283 +1,327 @@
--[[
- mattata v2.0 - PostgreSQL Database Module
+ mattata v2.1 - PostgreSQL Database Module
Uses pgmoon for async-compatible PostgreSQL connections.
- Implements connection pooling, automatic reconnection, and transaction helpers.
+ Implements connection pooling with copas semaphore guards,
+ automatic reconnection, and transaction helpers.
]]
local database = {}
local pgmoon = require('pgmoon')
local config = require('src.core.config')
local logger = require('src.core.logger')
+local copas_sem = require('copas.semaphore')
-- Idle connections available for reuse.
local pool = {}
-- Defaults; database.connect() overwrites both from DATABASE_POOL_SIZE / DATABASE_TIMEOUT.
local pool_size = 10
local pool_timeout = 30000
-- copas semaphore bounding concurrent checkouts. It stays nil until
-- database.connect() succeeds (e.g. under mocks), and acquire/release then
-- skip all permit bookkeeping.
local pool_semaphore = nil
-- Lazily cached result of config.database().
local db_config = nil

--- Return the database connection settings, reading them from config only once.
local function get_config()
  db_config = db_config or config.database()
  return db_config
end
--- Open a new pgmoon connection using the cached database settings.
-- The connection's timeout is set from `pool_timeout`.
-- @return the connection, or nil and the connect error
local function create_connection()
  local settings = get_config()
  local conn = pgmoon.new({
    host = settings.host,
    port = settings.port,
    database = settings.database,
    user = settings.user,
    password = settings.password
  })
  local connected, connect_err = conn:connect()
  if connected then
    conn:settimeout(pool_timeout)
    return conn
  end
  return nil, connect_err
end
--- Validate credentials and prepare the pool.
-- Reads the pool size/timeout from config, opens one connection (kept idle
-- in the pool), then creates the semaphore that caps concurrent checkouts.
-- @return true on success, or false and the connection error
function database.connect()
  local cfg = get_config()
  pool_size = config.get_number('DATABASE_POOL_SIZE', 10)
  pool_timeout = config.get_number('DATABASE_TIMEOUT', 30000)

  local first, err = create_connection()
  if not first then
    logger.error('Failed to connect to PostgreSQL: %s', tostring(err))
    return false, err
  end
  pool[#pool + 1] = first

  -- max = pool_size, start = pool_size (all permits available), timeout = 30s
  pool_semaphore = copas_sem.new(pool_size, pool_size, 30)

  logger.info('Connected to PostgreSQL at %s:%d/%s (pool size: %d)', cfg.host, cfg.port, cfg.database, pool_size)
  return true
end
--- Check out a connection, preferring an idle pooled one.
-- When the semaphore exists, a permit is taken first, waiting up to 30s.
-- That permit is handed back here only if no connection could be produced;
-- otherwise the caller returns it through database.release().
-- @return a connection, or nil and an error message
function database.acquire()
  if pool_semaphore then
    local granted, sem_err = pool_semaphore:take(1, 30)
    if not granted then
      logger.error('Failed to acquire pool permit: %s', tostring(sem_err))
      return nil, 'Pool exhausted (semaphore timeout)'
    end
  end

  local idle = table.remove(pool)
  if idle then
    return idle
  end

  -- No idle connection: open a new one.
  local fresh, err = create_connection()
  if fresh then
    return fresh
  end
  logger.error('Failed to create new connection: %s', tostring(err))
  if pool_semaphore then pool_semaphore:give(1) end
  return nil, err
end
--- Return a connection obtained from database.acquire().
-- The connection is kept idle while the pool has room and closed otherwise.
-- In both cases the caller's semaphore permit is handed back (when a semaphore exists).
-- A nil connection is ignored.
function database.release(pg)
  if not pg then return end
  if #pool >= pool_size then
    pcall(function() pg:disconnect() end)
  else
    pool[#pool + 1] = pg
  end
  if pool_semaphore then
    pool_semaphore:give(1)
  end
end
--- Run a raw SQL string on a pooled connection.
-- If the query fails with an error mentioning 'closed', 'broken' or
-- 'timeout', the connection is treated as lost. It is dropped, a fresh one
-- is opened, and the query is retried exactly once.
-- The semaphore permit taken by database.acquire() is held for the whole
-- retry: it stands for this caller's pool slot, not for the dead socket.
-- This means the retry never re-queues behind other coroutines or fails
-- for lack of a permit.
-- NOTE(review): a 'timeout' may mean the statement already reached the
-- server; retrying a non-idempotent statement could apply it twice.
-- @tparam string sql  complete SQL text (extra arguments are ignored)
-- @return the pgmoon result on success, or nil and an error message
function database.query(sql, ...)
  local pg, acquire_err = database.acquire()
  if not pg then
    -- acquire() fails on semaphore timeout as well as on connect errors; log the real cause.
    logger.error('Database not connected: %s', tostring(acquire_err))
    return nil, 'Database not connected'
  end

  local result, query_err = pg:query(sql)
  if result then
    database.release(pg)
    return result
  end

  local connection_lost = query_err and (
    query_err:match('closed') or query_err:match('broken') or query_err:match('timeout')
  )
  if not connection_lost then
    logger.error('Query failed: %s\nSQL: %s', tostring(query_err), sql)
    database.release(pg)
    return nil, query_err
  end

  logger.warn('Connection lost, attempting reconnect...')
  pcall(function() pg:disconnect() end)

  local fresh, connect_err = create_connection()
  if not fresh then
    -- No connection will reach database.release(), so hand the permit back here.
    if pool_semaphore then pool_semaphore:give(1) end
    logger.error('Reconnect failed for query: %s', tostring(connect_err))
    return nil, query_err
  end

  result, query_err = fresh:query(sql)
  -- release() pools (or closes) the fresh connection and returns the permit.
  database.release(fresh)
  if result then
    return result
  end
  logger.error('Reconnect failed for query: %s', tostring(query_err))
  return nil, query_err
end
-- Execute a parameterized query (manually escape values)
-- Each `$n` in `sql` is replaced client-side by the escaped n-th entry of
-- `params`:
--   * numbers via tostring;
--   * booleans as TRUE/FALSE;
--   * anything else via pg:escape_literal(tostring(v)).
-- `$n` placeholders without a value are left untouched. Connection-loss errors
-- ('closed' / 'broken' / 'timeout') trigger one retry on a fresh connection.
-- Returns the pgmoon result, or nil plus the error.
function database.execute(sql, params)
- local pg, err = database.acquire()
+ local pg, _ = database.acquire()
if not pg then
return nil, 'Database not connected'
end
if params then
local escaped = {}
-- NOTE(review): ipairs() stops at the first nil, so the `v == nil` branch
-- below never runs and params after a nil stay as literal `$n` text;
-- iterating up to params.n (or the highest index) would fix this.
for i, v in ipairs(params) do
if v == nil then
escaped[i] = 'NULL'
elseif type(v) == 'number' then
escaped[i] = tostring(v)
elseif type(v) == 'boolean' then
escaped[i] = v and 'TRUE' or 'FALSE'
else
escaped[i] = pg:escape_literal(tostring(v))
end
end
-- Replace $1, $2, etc. with escaped values
sql = sql:gsub('%$(%d+)', function(n)
return escaped[tonumber(n)] or '$' .. n
end)
end
local result, query_err = pg:query(sql)
if not result then
-- Attempt reconnect on connection failure
if query_err and (query_err:match('closed') or query_err:match('broken') or query_err:match('timeout')) then
logger.warn('Connection lost during execute, reconnecting...')
pcall(function() pg:disconnect() end)
+ -- Release the dead connection's permit before reconnect
+ if pool_semaphore then pool_semaphore:give(1) end
-- NOTE(review): the create_connection() error is discarded, so a failed
-- reconnect reports the original query error instead.
local new_pg
- new_pg, err = create_connection()
+ new_pg, _ = create_connection()
if new_pg then
+ -- Re-acquire a permit for the new connection
+ if pool_semaphore then
+ local ok, sem_err = pool_semaphore:take(1, 30)
+ if not ok then
+ pcall(function() new_pg:disconnect() end)
+ logger.error('Reconnect semaphore acquire failed: %s', tostring(sem_err))
+ return nil, 'Pool exhausted during reconnect'
+ end
+ end
result, query_err = new_pg:query(sql)
if result then
database.release(new_pg)
return result
end
database.release(new_pg)
end
else
database.release(pg)
end
logger.error('Query failed: %s\nSQL: %s', tostring(query_err), sql)
return nil, query_err
end
database.release(pg)
return result
end
-- Run a function inside a transaction (BEGIN / COMMIT / ROLLBACK)
-- `fn` is called as fn(query, execute). Both helpers run on this transaction's
-- single connection, and `execute` applies the same $n escaping as
-- database.execute. If fn raises, ROLLBACK is issued and nil plus the error is
-- returned; otherwise COMMIT is issued and fn's first return value is returned.
function database.transaction(fn)
- local pg, err = database.acquire()
+ local pg, _ = database.acquire()
if not pg then
return nil, 'Database not connected'
end
local ok, begin_err = pg:query('BEGIN')
if not ok then
database.release(pg)
return nil, begin_err
end
-- Build a scoped query function for this connection
local function scoped_query(sql)
return pg:query(sql)
end
local function scoped_execute(sql, params)
if params then
local escaped = {}
for i, v in ipairs(params) do
if v == nil then
escaped[i] = 'NULL'
elseif type(v) == 'number' then
escaped[i] = tostring(v)
elseif type(v) == 'boolean' then
escaped[i] = v and 'TRUE' or 'FALSE'
else
escaped[i] = pg:escape_literal(tostring(v))
end
end
sql = sql:gsub('%$(%d+)', function(n)
return escaped[tonumber(n)] or '$' .. n
end)
end
return pg:query(sql)
end
local success, result = pcall(fn, scoped_query, scoped_execute)
-- NOTE(review): the COMMIT/ROLLBACK results are not checked, so a failed
-- COMMIT is still reported as success. COMMIT is also issued when fn returns
-- normally after one of its queries *returned* (rather than raised) an error.
if success then
pg:query('COMMIT')
database.release(pg)
return result
else
pg:query('ROLLBACK')
database.release(pg)
logger.error('Transaction failed: %s', tostring(result))
return nil, result
end
end
-- Convenience: insert one row and return it (INSERT ... RETURNING *).
-- Values are bound through database.execute()'s $n escaping. Table and column
-- names are interpolated verbatim, so they must come from trusted code.
-- An empty (or nil) `data` table inserts a row of column defaults via
-- DEFAULT VALUES instead of emitting the invalid "() VALUES ()".
-- @param table_name string: target table
-- @param data table|nil: column -> value map
-- @return the database.execute() result, or nil plus an error
function database.insert(table_name, data)
    local columns, placeholders, params = {}, {}, {}
    local n = 0
    for column, value in pairs(data or {}) do
        n = n + 1
        columns[n] = column
        placeholders[n] = '$' .. n
        params[n] = value
    end
    local sql
    if n == 0 then
        sql = string.format('INSERT INTO %s DEFAULT VALUES RETURNING *', table_name)
    else
        sql = string.format(
            'INSERT INTO %s (%s) VALUES (%s) RETURNING *',
            table_name,
            table.concat(columns, ', '),
            table.concat(placeholders, ', ')
        )
    end
    return database.execute(sql, params)
end
-- Convenience: upsert (INSERT ... ON CONFLICT ... RETURNING *).
-- @param table_name string: interpolated verbatim (trusted input only)
-- @param data table: column -> value map, bound via database.execute() $n escaping
-- @param conflict_keys table|nil: list of columns forming the conflict target
-- @param update_keys table|nil: columns overwritten from EXCLUDED on conflict.
--   When nil or empty, DO NOTHING is used instead of the invalid empty
--   "DO UPDATE SET", so a conflicting row yields no RETURNING row.
-- @return the database.execute() result, or nil plus an error
function database.upsert(table_name, data, conflict_keys, update_keys)
    local columns, placeholders, params = {}, {}, {}
    local n = 0
    for column, value in pairs(data) do
        n = n + 1
        columns[n] = column
        placeholders[n] = '$' .. n
        params[n] = value
    end
    local target = ''
    if conflict_keys and #conflict_keys > 0 then
        target = ' (' .. table.concat(conflict_keys, ', ') .. ')'
    end
    local action = 'DO NOTHING'
    if update_keys and #update_keys > 0 then
        -- PostgreSQL requires an explicit conflict target for DO UPDATE
        if target == '' then
            return nil, 'upsert: conflict_keys are required when update_keys are given'
        end
        local assignments = {}
        for i, key in ipairs(update_keys) do
            assignments[i] = key .. ' = EXCLUDED.' .. key
        end
        action = 'DO UPDATE SET ' .. table.concat(assignments, ', ')
    end
    local sql = string.format(
        'INSERT INTO %s (%s) VALUES (%s) ON CONFLICT%s %s RETURNING *',
        table_name,
        table.concat(columns, ', '),
        table.concat(placeholders, ', '),
        target,
        action
    )
    return database.execute(sql, params)
end
-- Hand out a raw pgmoon connection for advanced usage.
-- The connection comes from database.acquire() and holds one pool permit, so
-- the caller must return it with database.release(pg) when finished.
-- Returns whatever database.acquire() returns (connection, or nil plus error).
function database.connection()
    local acquire = database.acquire
    return acquire()
end
-- Snapshot of the connection pool.
-- @return table { available = number of idle pooled connections,
--                 max_size = configured pool size }
function database.pool_stats()
    local idle = #pool
    return { max_size = pool_size, available = idle }
end
-- Disconnect every idle pooled connection, empty the pool and drop the semaphore.
-- NOTE(review): connections currently checked out are not closed here. Once
-- pool_semaphore is nil, their later release() calls re-pool them without
-- any permit accounting.
function database.disconnect()
for _, pg in ipairs(pool) do
pcall(function() pg:disconnect() end)
end
pool = {}
+ pool_semaphore = nil
logger.info('Disconnected from PostgreSQL (pool drained)')
end
return database
diff --git a/src/core/redis.lua b/src/core/redis.lua
index 5d940d8..c75df8a 100644
--- a/src/core/redis.lua
+++ b/src/core/redis.lua
@@ -1,257 +1,313 @@
--[[
- mattata v2.0 - Redis Connection Module
+ mattata v2.1 - Redis Connection Pool Module
Redis is used as cache/session store only. PostgreSQL is the primary database.
- Includes automatic reconnection with backoff, SCAN replacement for KEYS, and pipeline support.
+ Implements connection pooling with copas semaphore guards,
+ automatic reconnection with backoff, SCAN replacement for KEYS, and pipeline support.
]]
local redis_mod = {}
local redis_lib = require('redis')
local config = require('src.core.config')
local logger = require('src.core.logger')
+local copas_sem = require('copas.semaphore')
-local client = nil
+local pool = {}
+local pool_size = 5
+local pool_semaphore = nil
local redis_cfg = nil
-local reconnect_attempts = 0
-local MAX_RECONNECT_ATTEMPTS = 10
-- Override hgetall to return key-value table instead of flat array:
-- { field1, value1, field2, value2, ... } -> { [field1] = value1, [field2] = value2, ... }
local function hgetall_pairs_to_map(reply)
    local map = {}
    local last = #reply
    local i = 1
    while i <= last do
        map[reply[i]] = reply[i + 1]
        i = i + 2
    end
    return map
end
redis_lib.commands.hgetall = redis_lib.command('hgetall', { response = hgetall_pairs_to_map })
-- Opens one Redis connection using config.redis_config() (loaded lazily on
-- first use), then AUTHs when a password is set and SELECTs when db is
-- non-zero. Returns the connection, or nil plus the error if the connect call
-- raised.
-local function do_connect()
+-- Create a single Redis connection
+local function create_connection()
if not redis_cfg then
redis_cfg = config.redis_config()
end
+ local conn
local ok, err = pcall(function()
- client = redis_lib.connect({
+ conn = redis_lib.connect({
host = redis_cfg.host,
port = redis_cfg.port
})
end)
if not ok then
- return false, err
+ return nil, err
end
-- NOTE(review): auth()/select() run outside the pcall, so if they raise the
-- error propagates to the caller instead of being returned as nil, err.
if redis_cfg.password and redis_cfg.password ~= '' then
- client:auth(redis_cfg.password)
+ conn:auth(redis_cfg.password)
end
if redis_cfg.db and redis_cfg.db ~= 0 then
- client:select(redis_cfg.db)
+ conn:select(redis_cfg.db)
end
- reconnect_attempts = 0
- return true
+ return conn
end
-- Pool helpers:
--   acquire()     takes one semaphore permit (waiting up to 10s), then returns
--                 the first pooled connection that answers PING, or opens a
--                 new one.
--   release(conn) re-pools up to pool_size connections and QUITs extras.
--   discard(conn) QUITs the connection.
-- release and discard both give the permit back, so every successful
-- acquire() must be matched by exactly one of them.
-- NOTE(review): create_connection() is called without pcall. If it raises
-- (e.g. from AUTH/SELECT), the permit taken by acquire() is never returned.
--- Automatic reconnection with exponential backoff
-local function ensure_connected()
- if client then
- -- Quick ping check
- local ok = pcall(function() client:ping() end)
- if ok then return true end
- logger.warn('Redis connection lost, attempting reconnect...')
- client = nil
+-- Acquire a connection from the pool
+local function acquire()
+ -- Take a semaphore permit (blocks coroutine if pool exhausted)
+ if pool_semaphore then
+ local ok, err = pool_semaphore:take(1, 10)
+ if not ok then
+ logger.error('Redis pool semaphore timeout: %s', tostring(err))
+ return nil, 'Redis pool exhausted'
+ end
+ end
end
- while reconnect_attempts < MAX_RECONNECT_ATTEMPTS do
- reconnect_attempts = reconnect_attempts + 1
- local backoff = math.min(2 ^ reconnect_attempts, 30)
- logger.info('Redis reconnect attempt %d/%d (backoff: %ds)', reconnect_attempts, MAX_RECONNECT_ATTEMPTS, backoff)
- local ok, err = do_connect()
+ -- Try pooled connections, discard dead ones
+ while #pool > 0 do
+ local conn = table.remove(pool)
+ local ok = pcall(function() conn:ping() end)
if ok then
- logger.info('Redis reconnected successfully')
- return true
+ return conn
end
- logger.warn('Redis reconnect failed: %s', tostring(err))
- local socket = require('socket')
- socket.sleep(backoff)
+ logger.warn('Discarding dead pooled Redis connection')
end
- logger.error('Redis reconnection failed after %d attempts', MAX_RECONNECT_ATTEMPTS)
- return false
+ -- Create fresh connection
+ local conn, err = create_connection()
+ if not conn then
+ logger.error('Failed to create Redis connection: %s', tostring(err))
+ if pool_semaphore then pool_semaphore:give(1) end
+ return nil, err
+ end
+ return conn
+end
+
+-- Release a connection back to the pool
+local function release(conn)
+ if not conn then return end
+ if #pool < pool_size then
+ table.insert(pool, conn)
+ else
+ pcall(function() conn:quit() end)
+ end
+ if pool_semaphore then pool_semaphore:give(1) end
+end
+
+-- Discard a connection without returning it to the pool
+local function discard(conn)
+ if conn then
+ pcall(function() conn:quit() end)
+ end
+ if pool_semaphore then pool_semaphore:give(1) end
end
-- Safe command wrapper with auto-reconnect
-- Runs conn[method_name](conn, ...) on a pooled connection. Any raised error
-- (not only connection loss) discards that connection and retries once on a
-- freshly acquired one. Returns the command's first result, or nil on failure
-- (errors are logged, never raised), so a nil reply and a failure look the
-- same to callers.
-local function safe_call(method, ...)
- if not ensure_connected() then
+-- method_name is a string like 'get', 'set', etc.
+local function safe_call(method_name, ...)
+ local conn, err = acquire()
+ if not conn then
return nil
end
- local ok, result = pcall(method, client, ...)
+ local ok, result = pcall(function(...)
+ return conn[method_name](conn, ...)
+ end, ...)
if not ok then
- -- Connection may have dropped mid-call
- logger.warn('Redis command failed: %s — retrying after reconnect', tostring(result))
- client = nil
- if ensure_connected() then
- ok, result = pcall(method, client, ...)
- if ok then return result end
+ -- Connection may have dropped mid-call — discard and retry once
+ logger.warn('Redis %s failed: %s — retrying after reconnect', method_name, tostring(result))
+ discard(conn)
+ conn, err = acquire()
+ if not conn then
+ logger.error('Redis reconnect failed: %s', tostring(err))
+ return nil
+ end
+ ok, result = pcall(function(...)
+ return conn[method_name](conn, ...)
+ end, ...)
+ if not ok then
+ logger.error('Redis %s failed after reconnect: %s', method_name, tostring(result))
+ discard(conn)
+ return nil
end
- logger.error('Redis command failed after reconnect: %s', tostring(result))
- return nil
end
+ release(conn)
return result
end
-- Load the Redis config and read REDIS_POOL_SIZE (default 5). Open one
-- connection to validate credentials and seed the pool, then create the
-- permit semaphore (timeout 10s).
-- Returns true on success, or false plus the connection error.
function redis_mod.connect()
redis_cfg = config.redis_config()
- local ok, err = do_connect()
- if not ok then
+ pool_size = config.get_number('REDIS_POOL_SIZE', 5)
+
+ -- Create initial connection to validate credentials
+ local conn, err = create_connection()
+ if not conn then
logger.error('Failed to connect to Redis: %s', tostring(err))
return false, err
end
- logger.info('Connected to Redis at %s:%d (db %d)', redis_cfg.host, redis_cfg.port, redis_cfg.db or 0)
+ table.insert(pool, conn)
+
+ -- Create semaphore to guard concurrent pool access
+ -- max = pool_size, start = pool_size (all permits available), timeout = 10s
+ pool_semaphore = copas_sem.new(pool_size, pool_size, 10)
+
+ logger.info('Connected to Redis at %s:%d (db %d, pool size: %d)', redis_cfg.host, redis_cfg.port, redis_cfg.db or 0, pool_size)
return true
end
-- Logs a deprecation warning and returns a connection from acquire() (or nil).
-- NOTE(review): that connection holds a pool semaphore permit, but the
-- module's release()/discard() helpers are local and no exported release is
-- visible here. Callers therefore cannot give the permit back, and each call
-- permanently consumes one of the pool_size permits.
--- Get the raw redis client
+-- Get the raw redis client (deprecated — prefer using proxy functions)
function redis_mod.client()
- ensure_connected()
- return client
+ logger.warn('redis.client() is deprecated — use redis proxy functions instead')
+ local conn = acquire()
+ return conn
end
-- Proxy common operations with auto-reconnect
-- Each proxy forwards to safe_call() with the redis-lua command name, so it
-- returns the command's first result, or nil on failure (errors are logged).
function redis_mod.get(key)
- return safe_call(client.get, key)
+ return safe_call('get', key)
end
function redis_mod.set(key, value)
- return safe_call(client.set, key, value)
+ return safe_call('set', key, value)
end
function redis_mod.setex(key, ttl, value)
- return safe_call(client.setex, key, ttl, value)
+ return safe_call('setex', key, ttl, value)
end
function redis_mod.setnx(key, value)
- return safe_call(client.setnx, key, value)
+ return safe_call('setnx', key, value)
end
function redis_mod.del(key)
- return safe_call(client.del, key)
+ return safe_call('del', key)
end
function redis_mod.exists(key)
- return safe_call(client.exists, key)
+ return safe_call('exists', key)
end
function redis_mod.expire(key, ttl)
- return safe_call(client.expire, key, ttl)
+ return safe_call('expire', key, ttl)
end
function redis_mod.incr(key)
- return safe_call(client.incr, key)
+ return safe_call('incr', key)
end
function redis_mod.incrby(key, amount)
- return safe_call(client.incrby, key, amount)
+ return safe_call('incrby', key, amount)
end
function redis_mod.hget(key, field)
- return safe_call(client.hget, key, field)
+ return safe_call('hget', key, field)
end
function redis_mod.hset(key, field, value)
- return safe_call(client.hset, key, field, value)
+ return safe_call('hset', key, field, value)
end
function redis_mod.hdel(key, field)
- return safe_call(client.hdel, key, field)
+ return safe_call('hdel', key, field)
end
-- Uses the overridden hgetall command above, so the result is a field -> value table
function redis_mod.hgetall(key)
- return safe_call(client.hgetall, key)
+ return safe_call('hgetall', key)
end
function redis_mod.hexists(key, field)
- return safe_call(client.hexists, key, field)
+ return safe_call('hexists', key, field)
end
function redis_mod.hincrby(key, field, increment)
- return safe_call(client.hincrby, key, field, increment)
+ return safe_call('hincrby', key, field, increment)
end
function redis_mod.sadd(key, value)
- return safe_call(client.sadd, key, value)
+ return safe_call('sadd', key, value)
end
function redis_mod.srem(key, value)
- return safe_call(client.srem, key, value)
+ return safe_call('srem', key, value)
end
function redis_mod.sismember(key, value)
- return safe_call(client.sismember, key, value)
+ return safe_call('sismember', key, value)
end
function redis_mod.smembers(key)
- return safe_call(client.smembers, key)
+ return safe_call('smembers', key)
end
-- List operations (used by AI plugin)
function redis_mod.rpush(key, value)
- return safe_call(client.rpush, key, value)
+ return safe_call('rpush', key, value)
end
function redis_mod.lrange(key, start, stop)
- return safe_call(client.lrange, key, start, stop)
+ return safe_call('lrange', key, start, stop)
end
function redis_mod.ltrim(key, start, stop)
- return safe_call(client.ltrim, key, start, stop)
+ return safe_call('ltrim', key, start, stop)
end
-- SCAN-based iteration — replaces all KEYS usage
-- Returns all keys matching pattern without blocking
-- Pages through SCAN (COUNT 100) on one pooled connection and collects every
-- key into a single list. If a page fails, the connection is discarded and the
-- keys gathered so far are returned (possibly partial, with no error signalled).
-- NOTE(review): the loop ends only when the cursor equals the string '0'. This
-- assumes redis-lua returns the cursor as a string — confirm.
function redis_mod.scan(pattern)
- if not ensure_connected() then
+ local conn = acquire()
+ if not conn then
return {}
end
local results = {}
local cursor = '0'
repeat
local ok, reply = pcall(function()
- return client:scan(cursor, { match = pattern, count = 100 })
+ return conn:scan(cursor, { match = pattern, count = 100 })
end)
- if not ok or not reply then break end
+ if not ok or not reply then
+ discard(conn)
+ return results
+ end
cursor = reply[1]
for _, key in ipairs(reply[2]) do
table.insert(results, key)
end
until cursor == '0'
+ release(conn)
return results
end
-- DEPRECATED alias for redis_mod.scan(); logs a warning on every call.
-- @param pattern string: MATCH pattern forwarded unchanged to scan
-- @return table list of matching keys (whatever scan returns)
function redis_mod.keys(pattern)
    logger.warn('redis.keys() called — prefer redis.scan() to avoid blocking')
    local scan = redis_mod.scan
    return scan(pattern)
end
-- Pipeline support: batch multiple commands and execute together
-- `fn` receives the pipeline object and queues commands on it. The results of
-- pipeline:execute() are returned, or nil if it raised (the error is logged).
-- NOTE(review): conn:pipeline() and fn(pipeline) run outside any pcall. If
-- either raises, the connection is neither released nor discarded, and its
-- semaphore permit leaks.
-- NOTE(review): confirm the installed redis-lua supports pipeline() without a
-- block argument, plus a pipeline:execute() method.
function redis_mod.pipeline(fn)
- if not ensure_connected() then
+ local conn = acquire()
+ if not conn then
return nil
end
- local pipeline = client:pipeline()
+ local pipeline = conn:pipeline()
fn(pipeline)
local ok, results = pcall(function()
return pipeline:execute()
end)
if not ok then
logger.error('Redis pipeline failed: %s', tostring(results))
+ discard(conn)
return nil
end
+ release(conn)
return results
end
-- QUIT every idle pooled connection, empty the pool and drop the semaphore.
-- Connections currently checked out via acquire() are not closed here.
function redis_mod.disconnect()
- if client then
- pcall(function() client:quit() end)
- client = nil
- logger.info('Disconnected from Redis')
+ for _, conn in ipairs(pool) do
+ pcall(function() conn:quit() end)
end
+ pool = {}
+ pool_semaphore = nil
+ logger.info('Disconnected from Redis (pool drained)')
end
return redis_mod
diff --git a/src/core/router.lua b/src/core/router.lua
index dba32b0..321828a 100644
--- a/src/core/router.lua
+++ b/src/core/router.lua
@@ -1,430 +1,410 @@
--[[
- mattata v2.0 - Event Router
+ mattata v2.1 - Event Router
Dispatches Telegram updates through middleware pipeline to plugins.
- Handles messages, callback queries, inline queries, and other events.
+ Uses copas coroutines via telegram-bot-lua's async system for concurrent
+ update processing — each update runs in its own coroutine.
]]
local router = {}
local json = require('dkjson')
-local socket = require('socket')
+local copas = require('copas')
local config = require('src.core.config')
local logger = require('src.core.logger')
local middleware_pipeline = require('src.core.middleware')
local session = require('src.core.session')
local permissions = require('src.core.permissions')
local i18n = require('src.core.i18n')
local tools
local api, loader, ctx_base
-- Import middleware modules
local mw_blocklist = require('src.middleware.blocklist')
local mw_rate_limit = require('src.middleware.rate_limit')
local mw_user_tracker = require('src.middleware.user_tracker')
local mw_language = require('src.middleware.language')
local mw_federation = require('src.middleware.federation')
local mw_captcha = require('src.middleware.captcha')
local mw_stats = require('src.middleware.stats')
-- Store the shared api/tools/loader/ctx_base references used by the handlers
-- below, and register the middleware chain in its fixed order:
-- blocklist, rate limit, federation, captcha, user tracker, language, stats.
function router.init(api_ref, tools_ref, loader_ref, ctx_base_ref)
    api, tools, loader, ctx_base = api_ref, tools_ref, loader_ref, ctx_base_ref
    local ordered = {
        mw_blocklist,
        mw_rate_limit,
        mw_federation,
        mw_captcha,
        mw_user_tracker,
        mw_language,
        mw_stats,
    }
    for _, mw in ipairs(ordered) do
        middleware_pipeline.use(mw)
    end
end
-- Build a fresh context for each update: a shallow copy of ctx_base plus
-- chat-type flags, the global-admin flag and a lazily-resolved admin check.
-- Admin check is lazy — only resolved when ctx:check_admin() is called.
-- @param message table: the message (or synthetic {from, chat}) for this update
-- @return table ctx
local function build_ctx(message)
    local ctx = {}
    for k, v in pairs(ctx_base) do
        ctx[k] = v
    end
    local chat_type = message.chat and message.chat.type
    -- A chat with no type (e.g. the `chat = {}` placeholder built for callbacks
    -- that carry no message) is neither a group nor private. This keeps
    -- check_admin() from querying group admins with a nil chat id.
    ctx.is_group = chat_type ~= nil and chat_type ~= 'private'
    ctx.is_supergroup = chat_type == 'supergroup'
    ctx.is_private = chat_type == 'private'
    ctx.is_global_admin = message.from and permissions.is_global_admin(message.from.id) or false
    -- Lazy admin check: only makes API call when first accessed
    -- Caches result for the lifetime of this context
    local admin_resolved = false
    local admin_value = false
    ctx.is_admin = false -- stays false until check_admin() resolves it
    function ctx:check_admin()
        if admin_resolved then
            return admin_value
        end
        admin_resolved = true
        if ctx.is_global_admin then
            admin_value = true
        elseif ctx.is_group and message.from then
            admin_value = permissions.is_group_admin(api, message.chat.id, message.from.id)
        end
        ctx.is_admin = admin_value
        return admin_value
    end
    -- For backward compat: admin plugins that check ctx.is_admin will still
    -- need to call ctx:check_admin() first. The router does this for admin_only plugins.
    ctx.is_mod = false
    return ctx
end
-- Sort/normalise a message object in place (ported from v1 mattata.sort_message).
-- Steps, in order:
--   * text falls back to caption, then '' so later code can treat it as a string
--   * '/command_arg' -> '/command arg'; '[/!#]start <payload>' -> '/<payload>'
--   * reply_to_message is moved to .reply (normalised recursively at the end)
--   * from.language_code is lower-cased with '-' -> '_'; two-letter codes
--     expand to '<xx>_<xx>', except 'en' -> 'en_us'; 'root' -> 'en_us'
--   * is_media / is_service_message flags are set
--   * caption_entities become .entities; then each text_mention's visible name
--     (first literal occurrence) is replaced by the mentioned user's ID
-- @param message table: Telegram message object
-- @return table the same, mutated message
local function sort_message(message)
    message.text = message.text or message.caption or ''
    -- Normalise /command_arg to /command arg
    message.text = message.text:gsub('^(/[%a]+)_', '%1 ')
    -- Deep-link support
    if message.text:match('^[/!#]start .-$') then
        message.text = '/' .. message.text:match('^[/!#]start (.-)$')
    end
    -- Shorthand reply alias
    if message.reply_to_message then
        message.reply = message.reply_to_message
        message.reply_to_message = nil
    end
    -- Normalise language code
    if message.from and message.from.language_code then
        local lc = message.from.language_code:lower():gsub('%-', '_')
        if #lc == 2 and lc ~= 'en' then
            lc = lc .. '_' .. lc
        elseif #lc == 2 or lc == 'root' then
            lc = 'en_us'
        end
        message.from.language_code = lc
    end
    -- Detect media
    message.is_media = message.photo or message.video or message.audio or message.voice
        or message.document or message.sticker or message.animation or message.video_note or false
    -- Detect service messages
    message.is_service_message = (message.new_chat_members or message.left_chat_member
        or message.new_chat_title or message.new_chat_photo or message.pinned_message
        or message.group_chat_created or message.supergroup_chat_created) and true or false
    -- Caption entities describe the caption, which is what .text holds when the
    -- message has no text; promote them first so caption mentions are handled too
    if message.caption_entities then
        message.entities = message.caption_entities
        message.caption_entities = nil
    end
    -- Entity-based text mentions -> ID substitution
    if message.entities then
        for _, entity in ipairs(message.entities) do
            if entity.type == 'text_mention' and entity.user then
                local text = message.text
                local name = text:sub(entity.offset + 1, entity.offset + entity.length)
                -- Plain find (not a pattern) so names containing magic characters
                -- such as '(', '.', '%' or '-' match literally and cannot raise
                local first, last = text:find(name, 1, true)
                if name ~= '' and first then
                    message.text = text:sub(1, first - 1) .. tostring(entity.user.id) .. text:sub(last + 1)
                end
                -- NOTE(review): Telegram offsets count UTF-16 code units while sub()
                -- is byte-based. Earlier rewrites (deep links, previous substitutions)
                -- also shift positions, so non-ASCII text may resolve the wrong span.
            end
        end
    end
    -- Sort reply recursively
    if message.reply then
        message.reply = sort_message(message.reply)
    end
    return message
end
-- Extract the command word and its argument string from message text.
-- Accepts '/', '!' or '#' as prefix and an optional '@botname' suffix.
-- Commands explicitly addressed to another bot ('/cmd@otherbot') are ignored;
-- the username comparison is case-insensitive.
-- @param text string|nil: message text
-- @param bot_username string|nil: this bot's username (without '@')
-- @return string|nil command (lower-cased), string|nil args (nil when empty)
local function extract_command(text, bot_username)
    if not text then return nil, nil end
    local cmd, rest = text:match('^[/!#]([%w_]+)(.*)$')
    if not cmd then return nil, nil end
    local target, after = rest:match('^@([%w_]+)(.*)$')
    if target then
        if bot_username and target:lower() ~= bot_username:lower() then
            return nil, nil
        end
        rest = after
    end
    local args = rest:match('^%s*(.*)$')
    if args == '' then args = nil end
    return cmd:lower(), args
end
-- Resolve aliases for a chat (with Redis caching).
-- For non-private chats, a leading command that matches a chat-defined alias
-- is rewritten to the aliased command. Aliases come from the
-- 'chat:<id>:aliases' hash and are cached as JSON under 'cache:aliases:<id>'
-- for 300 seconds. Matching is case-insensitive on the command word only; the
-- rest of the text (the arguments) keeps its original case.
-- @param message table: normalised message (text is a string)
-- @param redis_mod table: redis proxy module (get / hgetall / setex)
-- @return table the same message; .text is rewritten and .is_alias = true on a hit
local function resolve_alias(message, redis_mod)
    if not message.text:match('^[/!#][%w_]+') then return message end
    if not message.chat or message.chat.type == 'private' then return message end
    -- Lower-case only the command word so the user's arguments keep their case
    local command, rest = message.text:match('^[/!#]([%w_]+)(.*)')
    if not command then return message end
    command = command:lower()
    local cache_key = 'cache:aliases:' .. message.chat.id
    local aliases
    local cached = redis_mod.get(cache_key)
    if cached then
        local ok, decoded = pcall(json.decode, cached)
        if ok and decoded then
            aliases = decoded
        end
    end
    if not aliases then
        aliases = redis_mod.hgetall('chat:' .. message.chat.id .. ':aliases')
        if type(aliases) == 'table' then
            pcall(function()
                redis_mod.setex(cache_key, 300, json.encode(aliases))
            end)
        end
    end
    if type(aliases) == 'table' then
        local original = aliases[command]
        if original then
            message.text = '/' .. original .. rest
            message.is_alias = true
        end
    end
    return message
end
-- Process action state (multi-step commands).
-- When the message replies to one of the bot's own messages and a pending
-- action is stored for that reply, the action text is prepended to the
-- message, the reply link is cleared and the stored action is deleted.
-- The reply's message_id is captured before message.reply is cleared.
-- @param message table: normalised message (see sort_message)
-- @param ctx table: not used here
-- @return table the same message
local function process_action(message, ctx)
    local reply = message.reply
    local replying_to_bot = message.text and message.chat and reply
        and reply.from and reply.from.id == api.info.id
    if not replying_to_bot then
        return message
    end
    local chat_id = message.chat.id
    local reply_id = reply.message_id
    local action = session.get_action(chat_id, reply_id)
    if action then
        message.text = action .. ' ' .. message.text
        message.reply = nil
        session.del_action(chat_id, reply_id)
    end
    return message
end
-- Handle a message update
-- Processing steps, in order:
--   1. drop updates without a sender, or dated more than 10s ago;
--   2. normalise the message, apply pending actions and chat aliases;
--   3. run the middleware pipeline;
--   4. dispatch a command to its plugin's on_message, subject to the
--      disabled/permanent, global_admin_only, admin_only and group_only checks;
--   5. run every plugin's passive on_new_message / on_member_join hooks.
-- NOTE(review): the global_admin_only / admin_only / group_only early returns
-- also skip the passive hooks for that message.
-- NOTE(review): on_member_join runs even for plugins disabled in this chat.
local function on_message(message)
-- Validate
if not message or not message.from then return end
if message.date and message.date < os.time() - 10 then return end
-- Sort/normalise
message = sort_message(message)
message = process_action(message, ctx_base)
message = resolve_alias(message, ctx_base.redis)
-- Build context and run middleware
local ctx = build_ctx(message)
local should_continue
ctx, should_continue = middleware_pipeline.run(ctx, message)
if not should_continue then return end
-- Dispatch command to matching plugin
local cmd, args = extract_command(message.text, api.info.username)
- local command_handled = false
if cmd then
local plugin = loader.get_by_command(cmd)
if plugin and plugin.on_message then
if not session.is_plugin_disabled(message.chat.id, plugin.name) or loader.is_permanent(plugin.name) then
-- Check permission requirements
if plugin.global_admin_only and not ctx.is_global_admin then
return
end
-- Resolve admin status only for admin_only plugins (lazy check)
if plugin.admin_only then
ctx:check_admin()
if not ctx.is_admin and not ctx.is_global_admin then
return api.send_message(message.chat.id, ctx.lang and ctx.lang.errors and ctx.lang.errors.admin or 'You need to be an admin to use this command.')
end
end
if plugin.group_only and ctx.is_private then
return api.send_message(message.chat.id, ctx.lang and ctx.lang.errors and ctx.lang.errors.supergroup or 'This command can only be used in groups.')
end
message.command = cmd
message.args = args
local ok, err = pcall(plugin.on_message, api, message, ctx)
if not ok then
logger.error('Plugin %s.on_message error: %s', plugin.name, tostring(err))
if config.log_chat() then
api.send_message(config.log_chat(), string.format(
'<pre>[%s] %s error:\n%s\nFrom: %s\nText: %s</pre>',
os.date('%X'), plugin.name,
tools.escape_html(tostring(err)),
message.from.id,
tools.escape_html(message.text or '')
), 'html')
end
end
- command_handled = true
end
end
end
-- Run passive handlers (on_new_message) for all non-disabled plugins
for _, plugin in ipairs(loader.get_plugins()) do
if plugin.on_new_message and not session.is_plugin_disabled(message.chat.id, plugin.name) then
local ok, err = pcall(plugin.on_new_message, api, message, ctx)
if not ok then
logger.error('Plugin %s.on_new_message error: %s', plugin.name, tostring(err))
end
end
-- Handle member join events
if message.new_chat_members and plugin.on_member_join then
local ok, err = pcall(plugin.on_member_join, api, message, ctx)
if not ok then
logger.error('Plugin %s.on_member_join error: %s', plugin.name, tostring(err))
end
end
end
end
-- Handle a callback query.
-- callback_query.data must look like '<plugin_name>:<payload>'. The payload
-- replaces .data before the named plugin's on_callback_query is called.
-- Only the global blocklist check is applied here (the middleware pipeline,
-- including rate limiting, is not run). ctx.lang is loaded from the user's
-- saved language setting (default 'en_gb').
local function on_callback_query(callback_query)
    if not (callback_query and callback_query.from and callback_query.data) then
        return
    end
    -- When no message is attached (presumably an inline-message callback),
    -- synthesise a minimal one for the plugin and for build_ctx
    local message = callback_query.message
    if not message then
        message = {
            chat = {},
            message_id = callback_query.inline_message_id,
            from = callback_query.from,
        }
    end
    local plugin_name, payload = callback_query.data:match('^(.-):(.*)$')
    if not plugin_name then return end
    local plugin = loader.get_by_name(plugin_name)
    if not (plugin and plugin.on_callback_query) then return end
    callback_query.data = payload
    local ctx = build_ctx(message)
    local user_id = callback_query.from.id
    if session.is_globally_blocklisted(user_id) then
        return
    end
    ctx.lang = i18n.get(session.get_setting(user_id, 'language') or 'en_gb')
    local ok, err = pcall(plugin.on_callback_query, api, callback_query, message, ctx)
    if not ok then
        logger.error('Plugin %s.on_callback_query error: %s', plugin_name, tostring(err))
    end
end
-- Handle an inline query by offering it to every plugin that defines
-- on_inline_query. Globally blocklisted users are ignored. The context is
-- built as if the query came from a private chat, and ctx.lang uses the
-- user's saved language (default 'en_gb').
local function on_inline_query(inline_query)
    local from = inline_query and inline_query.from
    if not from then return end
    if session.is_globally_blocklisted(from.id) then return end
    local ctx = build_ctx({ from = from, chat = { type = 'private' } })
    ctx.lang = i18n.get(session.get_setting(from.id, 'language') or 'en_gb')
    for _, plugin in ipairs(loader.get_plugins()) do
        local handler = plugin.on_inline_query
        if handler then
            local ok, err = pcall(handler, api, inline_query, ctx)
            if not ok then
                logger.error('Plugin %s.on_inline_query error: %s', plugin.name, tostring(err))
            end
        end
    end
end
--- Run cron jobs asynchronously in coroutines
-local function run_cron_async()
- for _, plugin in ipairs(loader.get_plugins()) do
- if plugin.cron then
- local co = coroutine.create(function()
- local ok, err = pcall(plugin.cron, api, ctx_base)
- if not ok then
- logger.error('Plugin %s cron error: %s', plugin.name, tostring(err))
- end
- end)
- coroutine.resume(co)
- end
- end
-end
-
--- Main polling loop
+-- Concurrent polling loop using telegram-bot-lua's async system
function router.run()
- local last_update = 0
- local last_cron = os.date('%M')
- local last_stats_flush = 0
local polling = config.polling()
- while true do
- local success = api.get_updates(
- polling.timeout,
- last_update + 1,
- polling.limit,
- json.encode({
- 'message', 'edited_message', 'callback_query', 'inline_query',
- 'chat_join_request', 'chat_member', 'my_chat_member',
- 'message_reaction'
- })
- )
-
- if success and success.result then
- for _, update in ipairs(success.result) do
- last_update = update.update_id
- local start_time = socket.gettime()
-
- if update.message or update.edited_message then
- local msg = update.message or update.edited_message
- if update.edited_message then
- msg.is_edited = true
- end
- local ok, err = pcall(on_message, msg)
- if not ok then
- logger.error('on_message error: %s', tostring(err))
- end
- elseif update.callback_query then
- local ok, err = pcall(on_callback_query, update.callback_query)
- if not ok then
- logger.error('on_callback_query error: %s', tostring(err))
- end
- elseif update.inline_query then
- local ok, err = pcall(on_inline_query, update.inline_query)
- if not ok then
- logger.error('on_inline_query error: %s', tostring(err))
- end
- end
+ -- Register telegram-bot-lua handler callbacks
+ -- api.process_update() dispatches to these inside per-update copas coroutines
+ api.on_message = function(msg)
+ local ok, err = pcall(on_message, msg)
+ if not ok then logger.error('on_message error: %s', tostring(err)) end
+ end
- if config.debug() then
- logger.debug('Update #%d processed in %.3fs', update.update_id, socket.gettime() - start_time)
- end
- end
- else
- logger.error('Failed to retrieve updates from Telegram API')
- end
+ api.on_edited_message = function(msg)
+ msg.is_edited = true
+ local ok, err = pcall(on_message, msg)
+ if not ok then logger.error('on_edited_message error: %s', tostring(err)) end
+ end
- -- Minutely cron jobs (async via coroutines)
- if last_cron ~= os.date('%M') then
- last_cron = os.date('%M')
- run_cron_async()
- end
+ api.on_callback_query = function(cb)
+ local ok, err = pcall(on_callback_query, cb)
+ if not ok then logger.error('on_callback_query error: %s', tostring(err)) end
+ end
- -- Flush stats counters to PostgreSQL every 5 minutes
- local now = os.time()
- if now - last_stats_flush >= 300 then
- last_stats_flush = now
- local co = coroutine.create(function()
- local ok, err = pcall(mw_stats.flush, ctx_base.db, ctx_base.redis)
- if not ok then
- logger.error('Stats flush error: %s', tostring(err))
+ api.on_inline_query = function(iq)
+ local ok, err = pcall(on_inline_query, iq)
+ if not ok then logger.error('on_inline_query error: %s', tostring(err)) end
+ end
+
+ -- Cron: copas background thread, runs every 60s
+ copas.addthread(function()
+ while true do
+ copas.pause(60)
+ for _, plugin in ipairs(loader.get_plugins()) do
+ if plugin.cron then
+ copas.addthread(function()
+ local ok, err = pcall(plugin.cron, api, ctx_base)
+ if not ok then
+ logger.error('Plugin %s cron error: %s', plugin.name, tostring(err))
+ end
+ end)
end
- end)
- coroutine.resume(co)
+ end
end
- end
+ end)
+
+ -- Stats flush: copas background thread, runs every 300s
+ copas.addthread(function()
+ while true do
+ copas.pause(300)
+ local ok, err = pcall(mw_stats.flush, ctx_base.db, ctx_base.redis)
+ if not ok then logger.error('Stats flush error: %s', tostring(err)) end
+ end
+ end)
+
+ -- Start concurrent polling loop
+ -- api.run() -> api.async.run() which:
+ -- 1. Swaps api.request to copas-based api.async.request
+ -- 2. Spawns polling coroutine calling get_updates in a loop
+ -- 3. For each update, spawns NEW coroutine -> api.process_update -> handlers above
+ -- 4. Calls copas.loop()
+ api.run({
+ timeout = polling.timeout,
+ limit = polling.limit,
+ allowed_updates = {
+ 'message', 'edited_message', 'callback_query', 'inline_query',
+ 'chat_join_request', 'chat_member', 'my_chat_member',
+ 'message_reaction'
+ }
+ })
end
return router
File Metadata
Details
Attached
Mime Type
text/x-diff
Expires
Sun, May 17, 9:13 AM (1 d, 12 h)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
63003
Default Alt Text
(77 KB)
Attached To
Mode
R69 mattata
Attached
Detach File
Event Timeline