Page Menu · Home · Phabricator (Chris)

No One · Temporary

Authored By
Unknown
Size
22 KB
Referenced Files
None
Subscribers
None
diff --git a/util/network/chat.cpp b/util/network/chat.cpp
index 3696901d..f68fa455 100644
--- a/util/network/chat.cpp
+++ b/util/network/chat.cpp
@@ -1,314 +1,314 @@
#include "chat.h"
#include <stdexcept>
using namespace Network;
using namespace Chat;
/* An empty message whose type is Unknown. */
Message::Message()
    : type(Unknown){
}

/* An empty message of the given type (for example a bare Ping). */
Message::Message(const Message::Type & type)
    : type(type){
}

/* Builds a message by reading one frame from `socket' (see Message::read).
 * Whatever read() throws propagates to the caller.
 */
Message::Message(Socket socket)
    : type(Unknown){
    this->read(socket);
}

/* A message with an explicit type, sender name and text, and no parameters. */
Message::Message(const Message::Type & type, const std::string & sender, const std::string & message)
    : type(type), sender(sender), message(message){
}

/* Member-wise copy of every field. */
Message::Message(const Message & copy)
    : type(copy.type), sender(copy.sender), message(copy.message), parameters(copy.parameters){
}

Message::~Message(){
}

/* Member-wise assignment; assigning a message to itself leaves it unchanged. */
const Message & Message::operator=(const Message & copy){
    if (&copy != this){
        type = copy.type;
        sender = copy.sender;
        message = copy.message;
        parameters = copy.parameters;
    }
    return *this;
}
/* Maps a message type to the name used for it on the wire.
 * Unknown and any unrecognised value both map to "unknown".
 */
static std::string convertType(const Message::Type & type){
    switch (type){
        case Message::Ping: return "ping";
        case Message::Chat: return "chat";
        case Message::Command: return "command";
        case Message::Unknown:
        default: return "unknown";
    }
}
/* Inverse of convertType(Type): maps a wire name back to a message type.
 * "unknown" and any unrecognised name both yield Message::Unknown.
 */
static Message::Type convertType(const std::string & type){
    if (type == "ping"){
        return Message::Ping;
    }
    if (type == "chat"){
        return Message::Chat;
    }
    if (type == "command"){
        return Message::Command;
    }
    return Message::Unknown;
}
void Message::read(Socket socket){
int16_t size = ::Network::read16(socket);
char * buffer = new char[size];
::Network::readBytes(socket, (uint8_t*) buffer, size);
char * position = buffer;
uint16_t nextSize = 0;
std::string typeString;
position = ::Network::parse16(position, &nextSize);
position = ::Network::parseString(position, &typeString, nextSize + 1);
type = convertType(typeString);
nextSize = 0;
position = ::Network::parse16(position, &nextSize);
position = ::Network::parseString(position, &sender, nextSize + 1);
nextSize = 0;
position = ::Network::parse16(position, &nextSize);
position = ::Network::parseString(position, &message, nextSize + 1);
uint16_t vectorSize = 0;
position = ::Network::parse16(position, &vectorSize);
for (int i = 0; i < vectorSize; ++i){
std::string next;
nextSize = 0;
position = ::Network::parse16(position, &nextSize);
position = ::Network::parseString(position, &next, nextSize + 1);
parameters.push_back(next);
}
delete[] buffer;
}
/* Serializes this message into one frame and writes it to `socket'.
 *
 * Frame layout (16-bit integers in network byte order):
 *   payload size (16 bits, not counting itself)
 *   type, sender, message: each a 16-bit length, the bytes, then a NUL
 *   parameter count (16 bits), then that many length/bytes/NUL strings
 *
 * Throws Network::NetworkException if the payload does not fit in the 16-bit
 * size field, and propagates exceptions from Network::sendBytes.
 */
void Message::send(Socket socket){
    const std::string stringType = convertType(type);

    /* Exact payload size: every byte written after the leading size field.
     * (The previous formula also added parameters.size() and an extra 16-bit
     * slot, which sent uninitialized trailing bytes.) */
    size_t size = (sizeof(uint16_t) + stringType.size() + 1) +
                  (sizeof(uint16_t) + sender.size() + 1) +
                  (sizeof(uint16_t) + message.size() + 1) +
                  sizeof(uint16_t); /* parameter count */
    for (std::vector<std::string>::iterator i = parameters.begin(); i != parameters.end(); ++i){
        size += sizeof(uint16_t) + i->size() + 1;
    }

    /* The receiver reads the size as an unsigned 16-bit value; anything larger
     * would be truncated and desynchronize the stream. */
    if (size > 0xffff){
        throw ::Network::NetworkException("Chat message is too large to send");
    }

    /* Zero-filled and freed automatically, even if sendBytes throws. */
    std::vector<char> buffer(size + sizeof(uint16_t), 0);
    char * position = &buffer[0];
    /* dump16 takes int16_t; only the 16-bit pattern matters on the wire. */
    position = ::Network::dump16(position, static_cast<int16_t>(size));
    position = ::Network::dump16(position, stringType.size());
    position = ::Network::dumpStr(position, stringType);
    position = ::Network::dump16(position, sender.size());
    position = ::Network::dumpStr(position, sender);
    position = ::Network::dump16(position, message.size());
    position = ::Network::dumpStr(position, message);
    position = ::Network::dump16(position, parameters.size());
    for (std::vector<std::string>::iterator i = parameters.begin(); i != parameters.end(); ++i){
        position = ::Network::dump16(position, i->size());
        position = ::Network::dumpStr(position, *i);
    }
    ::Network::sendBytes(socket, (uint8_t*) &buffer[0], static_cast<int>(buffer.size()));
}
/* The message's type. */
const Message::Type & Message::getType() const{
    return this->type;
}
/* The sender's name, as supplied to the constructor or read from the wire. */
const std::string & Message::getName() const{
    return this->sender;
}
/* The message text. */
const std::string & Message::getMessage() const{
    return this->message;
}
/* Replaces the parameter list with a copy of `parameters'. */
void Message::setParameters(const std::vector<std::string> & parameters){
    this->parameters.assign(parameters.begin(), parameters.end());
}
/* The current parameter list. */
const std::vector<std::string> & Message::getParameters() const{
    return this->parameters;
}
Threadable::Threadable(){
}
Threadable::~Threadable(){
}
static void * run_thread(void * t){
Threadable * thread = (Threadable *)t;
thread->run();
return NULL;
}
void Threadable::start(){
::Util::Thread::createThread(&thread, NULL, (::Util::Thread::ThreadFunction) run_thread, this);
}
void Threadable::join(){
::Util::Thread::joinThread(thread);
}
Client::Client(int id, Network::Socket socket):
id(id),
socket(socket),
end(false),
valid(true){
}
Client::~Client(){
}
void Client::run(){
while (!end){
try {
::Util::Thread::ScopedLock scope(lock);
Message message(socket);
messages.push(message);
} catch (const Network::MessageEnd & ex){
end = true;
}
}
valid = false;
}
int Client::getId() const {
return id;
}
void Client::sendMessage(Message & message){
::Util::Thread::ScopedLock scope(lock);
try{
message.send(socket);
} catch (const ::Network::NetworkException & ex){
valid = false;
}
}
bool Client::hasMessages() const{
::Util::Thread::ScopedLock scope(lock);
return !messages.empty();
}
Message Client::nextMessage() {
::Util::Thread::ScopedLock scope(lock);
Message message = messages.front();
messages.pop();
return message;
}
void Client::shutdown(){
::Util::Thread::ScopedLock scope(lock);
end = true;
Network::close(socket);
join();
}
bool Client::isValid() const {
::Util::Thread::ScopedLock scope(lock);
return valid;
}
/* Opens a reliable (NL_RELIABLE) socket on `port' and starts listening on it.
 * Connections are accepted later by Server::run().
 * Network::openReliable throws InvalidPortException and Network::listen
 * throws CannotListenException on failure; neither is caught here.
 */
Server::Server(int port):
end(false){
- remote = Network::open(port);
+ remote = Network::openReliable(port);
Network::listen(remote);
Global::debug(0) << "Waiting for a connection on port " << port << std::endl;
}
Server::~Server(){
}
void Server::run(){
int idList = 0;
while (!end){
Util::ReferenceCount<Client> client = Util::ReferenceCount<Client>(new Client(idList++, Network::accept(remote)));
client->start();
clients.push_back(client);
Global::debug(0) << "Got a connection" << std::endl;
}
}
void Server::poll(){
for (std::vector< Util::ReferenceCount<Client> >::iterator i = clients.begin(); i != clients.end(); ++i){
Util::ReferenceCount<Client> client = *i;
while (client->hasMessages()){
Message message = client->nextMessage();
relay(client->getId(), message);
messages.push(message);
}
}
}
void Server::cleanup(){
for (std::vector< Util::ReferenceCount<Client> >::iterator i = clients.begin(); i != clients.end(); ++i){
Util::ReferenceCount<Client> client = *i;
Message ping(Message::Ping);
client->sendMessage(ping);
if (!client->isValid()){
client->shutdown();
i = clients.erase(i);
}
}
}
bool Server::hasMessages() const{
return !messages.empty();
}
Message Server::nextMessage() {
Message message = messages.front();
messages.pop();
return message;
}
void Server::global(Message & message){
for (std::vector< Util::ReferenceCount<Client> >::iterator i = clients.begin(); i != clients.end(); ++i){
Util::ReferenceCount<Client> client = *i;
client->sendMessage(message);
}
}
void Server::relay(int id, Message & message){
for (std::vector< Util::ReferenceCount<Client> >::iterator i = clients.begin(); i != clients.end(); ++i){
Util::ReferenceCount<Client> client = *i;
if (client->getId() != id){
client->sendMessage(message);
}
}
}
void Server::shutdown(){
::Util::Thread::ScopedLock scope(lock);
end = true;
for (std::vector< Util::ReferenceCount<Client> >::iterator i = clients.begin(); i != clients.end(); ++i){
Util::ReferenceCount<Client> client = *i;
client->shutdown();
}
Network::close(remote);
join();
}
diff --git a/util/network/network.cpp b/util/network/network.cpp
index 11118082..0f299c9f 100644
--- a/util/network/network.cpp
+++ b/util/network/network.cpp
@@ -1,332 +1,354 @@
#ifdef HAVE_NETWORKING
#include "hawknl/nl.h"
#endif
#include "network.h"
#include "util/debug.h"
#include <string>
#include <sstream>
#include <string.h>
#include "util/system.h"
#include "util/compress.h"
#include "util/thread.h"
#ifdef HAVE_NETWORKING
#ifdef WII
#include <network.h>
#elif defined(WINDOWS)
#include <winsock.h>
#else
#include <arpa/inet.h>
#endif
#else
#ifndef htonl
#define htonl(x) x
#endif
#ifndef htons
#define htons(x) x
#endif
#ifndef ntohl
#define ntohl(x) x
#endif
#ifndef ntohs
#define ntohs(x) x
#endif
#endif
using namespace std;
/* open_sockets is guarded by socketsLock (defined below); every access to it in this file takes that lock. */
namespace Network{
NetworkException::~NetworkException() throw (){
}
MessageEnd::MessageEnd(){
}
/* Message format: "Invalid port <port>. <message>". */
InvalidPortException::InvalidPortException( int port, const string message ):
NetworkException(""){
    ostringstream text;
    text << "Invalid port " << port << ". " << message;
    this->setMessage(text.str());
}
/*
template <typename M>
int messageSize(const M& message);
*/
/*
template <>
int messageSize<Message>(Message const & message){
return message.size();
}
template <>
int messageSize<Message*>(Message* const & message){
return message->size();
}
*/
/*
template <class M>
uint8_t * messageDump(const M& message, uint8_t * buffer);
template <>
uint8_t * messageDump<Message>(const Message & message, uint8_t * buffer){
return message.dump(buffer);
}
template <>
uint8_t * messageDump<Message*>(Message* const & message, uint8_t * buffer){
return message->dump(buffer);
}
*/
#ifdef HAVE_NETWORKING
/* Describes the most recent HawkNL failure for appending to exception messages:
 *   " HawkNL error: '<library error>' HawkNL system error: '<system error>'"
 * (The closing quote after the system error was previously missing.)
 */
static string getHawkError(){
    return string(" HawkNL error: '") +
           string(nlGetErrorStr(nlGetError())) +
           string("' HawkNL system error: '") +
           string(nlGetSystemErrorStr(nlGetSystemError())) +
           string("'");
}
/* Reads exactly sizeof(X) raw bytes from `socket' into an X.
 * No byte-order conversion is done here; callers apply ntohs/ntohl.
 */
template<typename X>
static X readX(Socket socket){
    X value;
    uint8_t * raw = reinterpret_cast<uint8_t *>(&value);
    readBytes(socket, raw, sizeof value);
    return value;
}
/* Reads one byte. */
int8_t read8(Socket socket){
    const uint8_t raw = readX<uint8_t>(socket);
    return raw;
}
/* Reads a 16-bit integer sent in network byte order. */
int16_t read16(Socket socket){
    const uint16_t wire = readX<uint16_t>(socket);
    return ntohs(wire);
}
/* Reads a 32-bit integer sent in network byte order. */
int32_t read32(Socket socket){
    const uint32_t wire = readX<uint32_t>(socket);
    return ntohl(wire);
}
/* Sends `bytes' as a 16-bit integer in network byte order. */
void send16(Socket socket, int16_t bytes){
    const int16_t wire = htons(bytes);
    sendBytes(socket, reinterpret_cast<const uint8_t *>(&wire), sizeof wire);
}
/* Writes `bytes' as 2 bytes in network byte order at `where' and returns where + 2.
 * memcpy is used because `where' is often not 2-byte aligned (it usually
 * follows a variable-length string). Storing through a uint16_t* there is
 * undefined behaviour (misalignment / strict aliasing).
 */
char * dump16(char * where, int16_t bytes){
    const uint16_t wire = htons(bytes);
    memcpy(where, &wire, sizeof(wire));
    return where + sizeof(uint16_t);
}
/* Reads 2 bytes in network byte order from `where' into `*out' and returns where + 2.
 * memcpy avoids the undefined behaviour of loading through a possibly
 * misaligned uint16_t*.
 */
char * parse16(char * where, uint16_t * out){
    uint16_t wire;
    memcpy(&wire, where, sizeof(wire));
    *out = ntohs(wire);
    return where + sizeof(uint16_t);
}
/* Reads a string of at most `length' bytes from `where' into `*out' and
 * returns where + length.
 * The copy stops at the first NUL inside the `length' bytes. If there is none,
 * all `length' bytes are taken. The copy therefore never reads beyond
 * where + length, even for an unterminated or hostile buffer.
 */
char * parseString(char * where, std::string * out, uint16_t length){
    const char * terminator = static_cast<const char *>(memchr(where, '\0', length));
    const size_t size = terminator != NULL ? static_cast<size_t>(terminator - where) : length;
    *out = std::string(where, size);
    return where + length;
}
/* Copies `str' plus its NUL terminator to `where' and returns the first byte
 * after the terminator.
 */
char * dumpStr(char * where, const std::string & str){
    const size_t withTerminator = str.size() + 1;
    memcpy(where, str.c_str(), withTerminator);
    return where + withTerminator;
}
/* Performs a single nlRead of up to `length' bytes and returns the text up to
 * the first NUL (or up to the bytes actually received).
 * The buffer is zero-filled, so a short read cannot leak uninitialized bytes
 * into the result. It is heap-allocated instead of a variable-length array,
 * which is not standard C++. Throws NetworkException if nlRead fails.
 */
string readStr(Socket socket, const uint16_t length){
    vector<char> buffer(static_cast<size_t>(length) + 1, '\0');
    const NLint bytes = nlRead(socket, &buffer[0], length);
    if (bytes == NL_INVALID){
        throw NetworkException(string("Could not read string.") + getHawkError());
    }
    return string(&buffer[0]);
}
/* Sends `str' followed by its NUL terminator in a single nlWrite.
 * A short write is treated as an error (NetworkException).
 */
void sendStr(Socket socket, const string & str){
    const NLint total = (signed)(str.length() + 1);
    const NLint written = nlWrite(socket, str.c_str(), total);
    if (written != total){
        throw NetworkException( string("Could not write string.") + getHawkError() );
    }
}
/* Writes all `length' bytes of `data', looping over partial writes.
 * Throws NetworkException if nlWrite fails.
 */
void sendBytes(Socket socket, const uint8_t * data, int length){
    int remaining = length;
    const uint8_t * cursor = data;
    while (remaining > 0){
        const int sent = nlWrite(socket, cursor, remaining);
        if (sent == NL_INVALID){
            throw NetworkException(string("Could not send bytes.") + getHawkError());
        }
        cursor += sent;
        remaining -= sent;
    }
}
/* Reads exactly `length' bytes into `data', looping over partial reads.
 * Throws MessageEnd when HawkNL reports NL_MESSAGE_END. Throws
 * NetworkException for any other read failure.
 */
void readBytes(Socket socket, uint8_t * data, int length){
    int remaining = length;
    uint8_t * cursor = data;
    while (remaining > 0){
        const int received = nlRead(socket, cursor, remaining);
        if (received == NL_INVALID){
            if (nlGetError() == NL_MESSAGE_END){
                throw MessageEnd();
            }
            throw NetworkException(string("Could not read bytes.") + getHawkError());
        }
        cursor += received;
        remaining -= received;
    }
}
/* Guards open_sockets: every access to that vector in this file is wrapped in
 * Util::Thread::acquireLock / releaseLock on this lock. It is initialized in init().
 */
Util::Thread::Lock socketsLock;
/* Opens an NL_RELIABLE HawkNL socket on `port' and records it in open_sockets.
 * Throws InvalidPortException (carrying the system error text) if nlOpen fails.
 * NOTE(review): openReliable and openUnreliable differ only in the nlOpen type
 * and the log text, so a shared helper would remove the duplication. The log
 * text "Attemping" is also misspelled.
 */
-Socket open(int port) throw (InvalidPortException){
+Socket openReliable(int port){
// NLsocket server = nlOpen( port, NL_RELIABLE_PACKETS );
- Global::debug(1, "network") << "Attemping to open port " << port << endl;
- Socket server = nlOpen( port, NL_RELIABLE );
+ Global::debug(1, "network") << "Attemping to open reliable port " << port << endl;
+ Socket server = nlOpen(port, NL_RELIABLE);
+ /* server will either be NL_INVALID (-1) or some low integer. hawknl
+ * sockets are mapped internally to real sockets, so don't be surprised
+ * if you get a socket back like 0.
+ */
+ if (server == NL_INVALID){
+ throw InvalidPortException(port, nlGetSystemErrorStr(nlGetSystemError()));
+ }
+ Global::debug(1, "network") << "Successfully opened a socket: " << server << endl;
+ Util::Thread::acquireLock(&socketsLock);
+ open_sockets.push_back(server);
+ Util::Thread::releaseLock(&socketsLock);
+ return server;
+}
+
/* Same as openReliable, but opens an NL_UNRELIABLE socket. */
+Socket openUnreliable(int port){
+ // NLsocket server = nlOpen( port, NL_RELIABLE_PACKETS );
+ Global::debug(1, "network") << "Attemping to open unreliable port " << port << endl;
+ Socket server = nlOpen(port, NL_UNRELIABLE);
/* server will either be NL_INVALID (-1) or some low integer. hawknl
* sockets are mapped internally to real sockets, so don't be surprised
* if you get a socket back like 0.
*/
if (server == NL_INVALID){
throw InvalidPortException(port, nlGetSystemErrorStr(nlGetSystemError()));
}
Global::debug(1, "network") << "Successfully opened a socket: " << server << endl;
Util::Thread::acquireLock(&socketsLock);
open_sockets.push_back(server);
Util::Thread::releaseLock(&socketsLock);
return server;
}
/* Resolves `server', opens a reliable socket and connects it to `server':`port'.
 * If nlConnect fails, the socket is closed (and removed from open_sockets)
 * before NetworkException("Could not connect") is thrown. openReliable may
 * also throw InvalidPortException.
 * NOTE(review): the return value of nlGetAddrFromName is not checked, so a
 * failed name lookup presumably surfaces only as "Could not connect".
 */
Socket connect(string server, int port) throw (NetworkException){
NLaddress address;
- nlGetAddrFromName( server.c_str(), &address);
+ nlGetAddrFromName(server.c_str(), &address);
nlSetAddrPort(&address, port);
- Socket socket = open( 0 );
+ /* The port that this socket has opened will be immediately rebound to some
+ * other port by sock_connect, but we still need to call openReliable to get
+ * an NL_RELIABLE socket.
+ */
+ Socket socket = openReliable(0);
if (nlConnect(socket, &address) == NL_FALSE){
close(socket);
- throw NetworkException( "Could not connect" );
+ throw NetworkException("Could not connect");
}
return socket;
}
/* Closes `s' and removes it from open_sockets. Sockets that were not opened
 * through this module (i.e. are not in open_sockets) are left untouched.
 */
void close(Socket s){
    Util::Thread::acquireLock(&socketsLock);
    vector<Socket>::iterator it = open_sockets.begin();
    while (it != open_sockets.end()){
        if (*it != s){
            ++it;
            continue;
        }
        Global::debug(1, "network") << "Closing socket " << s << endl;
        nlClose(*it);
        Global::debug(1, "network") << "Closed" << endl;
        it = open_sockets.erase(it);
    }
    Util::Thread::releaseLock(&socketsLock);
}
/* Closes every socket recorded in open_sockets and empties the list. */
void closeAll(){
    Global::debug(1, "network") << "Closing all sockets" << std::endl;
    Util::Thread::acquireLock(&socketsLock);
    for (size_t index = 0; index < open_sockets.size(); ++index){
        nlClose(open_sockets[index]);
    }
    open_sockets.clear();
    Util::Thread::releaseLock(&socketsLock);
}
/* One-time HawkNL setup:
 *  - initialize the library and select the IP driver;
 *  - make blocking I/O the default for new sockets (use blocking(false) to change it);
 *  - initialize the lock that guards open_sockets.
 * NOTE(review): the result of nlInit is not checked.
 */
void init(){
    nlInit();
    nlSelectNetwork(NL_IP);
    nlEnable(NL_BLOCKING_IO);
    Util::Thread::initializeLock(&socketsLock);
}
/* Enables/disables blocking I/O on one socket; returns true on success. */
bool blocking(Socket s, bool b){
    return NL_TRUE == nlSetSocketOpt(s, NL_BLOCKING_IO, b);
}
/* Sets whether blocking I/O is the default for sockets opened afterwards. */
void blocking(bool b){
    if (!b){
        nlDisable(NL_BLOCKING_IO);
        return;
    }
    nlEnable(NL_BLOCKING_IO);
}
/* Enables/disables TCP_NODELAY (Nagle) on one socket; returns true on success. */
bool noDelay(Socket s, bool b){
    return NL_TRUE == nlSetSocketOpt(s, NL_TCP_NO_DELAY, b);
}
/* Puts `s' into listening mode; throws CannotListenException with the system
 * error text on failure.
 */
void listen( Socket s ) throw( NetworkException ){
    if ( nlListen( s ) != NL_FALSE ){
        return;
    }
    const string reason(nlGetSystemErrorStr( nlGetSystemError() ));
    throw CannotListenException( reason );
}
/* Accepts one pending connection on the listening socket `s' and records the
 * new socket in open_sockets.
 * Throws NoConnectionsPendingException when HawkNL reports NL_NO_PENDING, and
 * NetworkException("Could not accept connection") for any other failure.
 */
Socket accept( Socket s ) throw( NetworkException ){
    const Socket connection = nlAcceptConnection( s );
    if ( connection == NL_INVALID ){
        const bool nothingPending = nlGetError() == NL_NO_PENDING;
        if ( nothingPending ){
            throw NoConnectionsPendingException();
        }
        throw NetworkException("Could not accept connection");
    }
    Util::Thread::acquireLock(&socketsLock);
    open_sockets.push_back(connection);
    Util::Thread::releaseLock(&socketsLock);
    return connection;
}
/* Shuts HawkNL down (counterpart of init()). */
void shutdown(){
    nlShutdown();
}
#else
/* Dummy implementations */
char * dump16(char * where, int16_t length){
return where;
}
int8_t read8(Socket socket){
return 0;
}
int16_t read16(Socket socket){
return 0;
}
char * dumpStr(char * where, const std::string & str){
return where;
}
void readBytes(Socket socket, uint8_t * data, int length){
}
void sendBytes(Socket socket, const uint8_t * data, int length){
}
char * parseString(char * where, std::string * out, uint16_t length){
return where;
}
#endif
}
diff --git a/util/network/network.h b/util/network/network.h
index 581c1c16..58b228ee 100644
--- a/util/network/network.h
+++ b/util/network/network.h
@@ -1,125 +1,126 @@
#ifndef _paintown_network_h
#define _paintown_network_h
#include <stdint.h>
#ifdef HAVE_NETWORKING
#include "hawknl/nl.h"
#endif
#include <string>
#include <vector>
#include <exception>
namespace Network{
/* Socket handle: a HawkNL socket when networking is compiled in, otherwise a
 * plain int so the dummy implementations in network.cpp still compile.
 */
#ifdef HAVE_NETWORKING
typedef NLsocket Socket;
#else
typedef int Socket;
#endif
/* NOTE(review): NO_CONNECTIONS_PENDING and NETWORK_ERROR are only referenced
 * in commented-out code in network.cpp (accept now throws instead), and
 * DATA_SIZE is not referenced in the visible code; confirm before relying on them.
 */
const int NO_CONNECTIONS_PENDING = 1;
const int NETWORK_ERROR = 2;
const int DATA_SIZE = 16;
/* Base class of every error raised by the Network module.
 * The description is available via getMessage() and, for handlers that only
 * catch std::exception, via what().
 */
class NetworkException: public std::exception{
public:
    NetworkException( const std::string message = "" ):std::exception(),message(message){}
    inline const std::string getMessage() const {
        return message;
    }
    /* Without this override, what() returns an implementation-defined
     * placeholder and the real message is lost to generic handlers. The
     * pointer stays valid as long as the exception object is alive.
     */
    virtual const char * what() const throw() {
        return message.c_str();
    }
    ~NetworkException() throw();
protected:
    inline void setMessage( const std::string & m ){
        this->message = m;
    }
private:
    std::string message;
};
/* Thrown by accept() when no connection is waiting. */
class NoConnectionsPendingException: public NetworkException{
public:
    NoConnectionsPendingException(const std::string message = ""):
    NetworkException(message){
    }
};
/* Thrown by readBytes() when HawkNL reports NL_MESSAGE_END. */
class MessageEnd: public NetworkException {
public:
    MessageEnd();
};
/* Thrown when a socket cannot be opened on a port; the message reads
 * "Invalid port <port>. <message>".
 */
class InvalidPortException: public NetworkException{
public:
    InvalidPortException( int port, const std::string message = "" );
};
/* Thrown by listen() when nlListen fails. */
class CannotListenException: public NetworkException{
public:
    CannotListenException( const std::string message = "" ):
    NetworkException( message ){
    }
};
/*
template <class M>
int totalSize(const std::vector<M> & messages);
template <class M>
void dump(const std::vector<M> & messages, uint8_t * buffer );
*/
/* Read a 1-, 2- or 4-byte integer from the socket. The 16- and 32-bit
 * versions convert from network byte order. They throw NetworkException
 * (MessageEnd at end of stream) on failure.
 */
int8_t read8(Socket socket);
int16_t read16(Socket socket);
int32_t read32(Socket socket);
/* Writes `length' as 2 bytes in network byte order at `where'; returns where + 2
 * (the non-networking dummy returns `where' unchanged).
 */
char * dump16(char * where, int16_t length);
/* Sends `length' as 2 bytes in network byte order. */
void send16(Socket socket, int16_t length);
/* Reads a string by expecting the string to be terminated with a null byte.
 * Performs a single nlRead of at most `length' bytes.
 */
std::string readStr(Socket socket, const uint16_t length);
/* This will send a string plus its null byte. If you just wanted to send the string
* without a null byte then use sendBytes(socket, str.c_str(), str.size())
*/
void sendStr(Socket socket, const std::string & str );
/* Write/read exactly `length' bytes, looping over partial transfers.
 * They throw NetworkException on failure; readBytes throws MessageEnd when
 * HawkNL reports NL_MESSAGE_END.
 */
void sendBytes(Socket socket, const uint8_t * data, int length);
void readBytes(Socket socket, uint8_t * data, int length);
/* Copies the string plus its null byte to the `where' buffer.
* Returns a pointer that is where + str.size() + 1
*/
char * dumpStr(char * where, const std::string & str);
/* Reads 2 bytes in network byte order from `where' into `out'; returns where + 2. */
char * parse16(char * where, uint16_t * out);
/* Reads a string into 'out' from 'where' that is expected to be 'length' bytes.
* Right now the function lies, it just does *out = string(where) so if the actual
* string is longer than 'length' the returned pointer will be into the middle
* of the 'where' buffer.
*/
char * parseString(char * where, std::string * out, uint16_t length);
/* init() sets up HawkNL (IP driver, blocking I/O by default) and the lock
 * guarding the socket list; presumably it must run before any other call.
 * shutdown() calls nlShutdown.
 */
void init();
void shutdown();
/* Whether or not blocking is enabled by default for new sockets */
void blocking(bool b);
/* Enable/disable blocking for a specific socket */
bool blocking(Socket s, bool b);
/* Enable/disable NODELAY -- the Nagle algorithm for TCP */
bool noDelay(Socket s, bool b);
/* NOTE(review): dynamic exception specifications such as throw (NetworkException)
 * are deprecated since C++11 and ill-formed in C++17. Removing them requires
 * changing these declarations and the matching definitions together.
 * accept() throws NoConnectionsPendingException when nothing is pending.
 */
void listen(Socket s) throw (NetworkException);
Socket accept(Socket s) throw (NetworkException);
/* openReliable/openUnreliable open an NL_RELIABLE / NL_UNRELIABLE socket on
 * `port' and throw InvalidPortException on failure. connect() opens a
 * reliable socket and connects it to server:port.
 */
-Socket open(int port) throw (InvalidPortException);
-Socket connect( std::string server, int port ) throw (NetworkException);
+Socket openReliable(int port);
+Socket openUnreliable(int port);
+Socket connect(std::string server, int port) throw (NetworkException);
/* close() closes a socket that this module opened or accepted (unknown
 * sockets are ignored); closeAll() closes all of them.
 */
void close(Socket);
void closeAll();
/* NOTE(review): `static' at namespace scope in a header gives every translation
 * unit that includes network.h its own separate, empty vector. Only
 * network.cpp's copy is used by the open/close functions. An extern
 * declaration plus one definition in network.cpp would make it a single object.
 */
static std::vector<Socket> open_sockets;
}
#endif

File Metadata

Mime Type
text/x-diff
Expires
Fri, Jun 19, 8:33 PM (1 w, 2 d ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
71850
Default Alt Text
(22 KB)

Event Timeline