diff options
author | Sven Gothel <[email protected]> | 2015-03-28 02:08:11 +0100 |
---|---|---|
committer | Sven Gothel <[email protected]> | 2015-03-28 02:08:11 +0100 |
commit | 450aa6f7df9e67dd256b86f94e65eaf707032aad (patch) | |
tree | 04aa207d84ddc8ca246d2573aaaf756b3ce8a0b5 /LibOVR/Src/Net | |
parent | 3c7b8a17e907f4ef2afd9f77db566a3f6179cbe4 (diff) | |
parent | 4207f9c279e832e3afcb3f5fc6cd8d84cb4cfe4c (diff) |
Merge branch 'vanilla_0.5.0.1' into jogamp_0.5.0.1
Conflicts:
LibOVR/Include/OVR_CAPI_0_5_0.h
LibOVR/Src/CAPI/CAPI_HMDState.cpp
LibOVR/Src/Displays/OVR_Win32_Dxgi_Display.h
LibOVR/Src/Kernel/OVR_System.cpp
LibOVR/Src/OVR_CAPI.cpp
LibOVR/Src/OVR_Profile.cpp
LibOVRKernel/Src/Kernel/OVR_ThreadsWinAPI.cpp
LibOVRKernel/Src/Kernel/OVR_Types.h
Diffstat (limited to 'LibOVR/Src/Net')
-rw-r--r-- | LibOVR/Src/Net/OVR_BitStream.cpp | 25 | ||||
-rw-r--r-- | LibOVR/Src/Net/OVR_BitStream.h | 17 | ||||
-rw-r--r-- | LibOVR/Src/Net/OVR_NetworkTypes.h | 2 | ||||
-rw-r--r-- | LibOVR/Src/Net/OVR_PacketizedTCPSocket.h | 4 | ||||
-rw-r--r-- | LibOVR/Src/Net/OVR_RPC1.cpp | 16 | ||||
-rw-r--r-- | LibOVR/Src/Net/OVR_RPC1.h | 15 | ||||
-rw-r--r-- | LibOVR/Src/Net/OVR_Session.cpp | 455 | ||||
-rwxr-xr-x[-rw-r--r--] | LibOVR/Src/Net/OVR_Session.h | 309 | ||||
-rw-r--r-- | LibOVR/Src/Net/OVR_Socket.h | 14 | ||||
-rw-r--r-- | LibOVR/Src/Net/OVR_Unix_Socket.cpp | 18 | ||||
-rw-r--r-- | LibOVR/Src/Net/OVR_Win32_Socket.cpp | 1226 | ||||
-rw-r--r-- | LibOVR/Src/Net/OVR_Win32_Socket.h | 301 |
12 files changed, 1297 insertions, 1105 deletions
diff --git a/LibOVR/Src/Net/OVR_BitStream.cpp b/LibOVR/Src/Net/OVR_BitStream.cpp index b565f22..054f871 100644 --- a/LibOVR/Src/Net/OVR_BitStream.cpp +++ b/LibOVR/Src/Net/OVR_BitStream.cpp @@ -107,6 +107,18 @@ BitStream::BitStream( char* _data, const unsigned int lengthInBytes, bool _copyD data = ( unsigned char* ) _data; } +void BitStream::WrapBuffer(unsigned char* _data, const unsigned int lengthInBytes) +{ + if (copyData && numberOfBitsAllocated > (BITSTREAM_STACK_ALLOCATION_SIZE << 3)) + OVR_FREE(data); // Use realloc and free so we are more efficient than delete and new for resizing + + numberOfBitsUsed = lengthInBytes << 3; + readOffset = 0; + copyData = false; + numberOfBitsAllocated = lengthInBytes << 3; + data = (unsigned char*)_data; +} + // Use this if you pass a pointer copy to the constructor (_copyData==false) and want to overallocate to prevent reallocation void BitStream::SetNumberOfBitsAllocated( const BitSize_t lengthInBits ) { @@ -183,8 +195,9 @@ void BitStream::Write( BitStream *bitStream, BitSize_t numberOfBits ) numberOfBitsUsed+=BYTES_TO_BITS(numBytes); } - while (numberOfBits-->0 && bitStream->readOffset + 1 <= bitStream->numberOfBitsUsed) + while (numberOfBits > 0 && bitStream->readOffset + 1 <= bitStream->numberOfBitsUsed) { + --numberOfBits; numberOfBitsMod8 = numberOfBitsUsed & 7; if ( numberOfBitsMod8 == 0 ) { @@ -974,14 +987,8 @@ void BitStream::AssertCopyData( void ) } bool BitStream::IsNetworkOrderInternal(void) { -#if defined(_PS3) || defined(__PS3__) || defined(SN_TARGET_PS3) - return true; -#elif defined(SN_TARGET_PSP2) - return false; -#else - static unsigned long htonlValue = htonl(12345); - return htonlValue == 12345; -#endif + static unsigned long htonlValue = htonl(12345); + return htonlValue == 12345; } void BitStream::ReverseBytes(unsigned char *inByteArray, unsigned char *inOutByteArray, const unsigned int length) { diff --git a/LibOVR/Src/Net/OVR_BitStream.h b/LibOVR/Src/Net/OVR_BitStream.h index b1ddc8f..4e2d2ef 100644 --- a/LibOVR/Src/Net/OVR_BitStream.h +++ b/LibOVR/Src/Net/OVR_BitStream.h @@ -29,9 +29,13 @@ limitations under the License. #define OVR_Bitstream_h #include <math.h> -#include "../Kernel/OVR_Types.h" -#include "../Kernel/OVR_Std.h" -#include "../Kernel/OVR_String.h" +#include "Kernel/OVR_Types.h" +#include "Kernel/OVR_Std.h" +#include "Kernel/OVR_String.h" + +#if defined(OVR_CC_MSVC) +#pragma warning(push) +#endif namespace OVR { namespace Net { @@ -76,6 +80,9 @@ public: /// Resets the bitstream for reuse. void Reset( void ); + // Releases the current data and points the bitstream at the provided buffer + void WrapBuffer(unsigned char* data, const unsigned int lengthInBytes); + /// \brief Bidirectional serialize/deserialize any integral type to/from a bitstream. /// \details Undefine __BITSTREAM_NATIVE_END if you need endian swapping. /// \param[in] writeToBitstream true to write from your data to this bitstream. False to read from this bitstream and write to your data @@ -1741,4 +1748,8 @@ BitStream& operator>>(BitStream& in, templateType& c) }} // OVR::Net +#if defined(OVR_CC_MSVC) +#pragma warning(pop) +#endif + #endif diff --git a/LibOVR/Src/Net/OVR_NetworkTypes.h b/LibOVR/Src/Net/OVR_NetworkTypes.h index 401f53a..c807bd0 100644 --- a/LibOVR/Src/Net/OVR_NetworkTypes.h +++ b/LibOVR/Src/Net/OVR_NetworkTypes.h @@ -28,7 +28,7 @@ limitations under the License. #ifndef OVR_NetworkTypes_h #define OVR_NetworkTypes_h -#include "../Kernel/OVR_Types.h" +#include "Kernel/OVR_Types.h" namespace OVR { namespace Net { diff --git a/LibOVR/Src/Net/OVR_PacketizedTCPSocket.h b/LibOVR/Src/Net/OVR_PacketizedTCPSocket.h index 8052bd3..e4bd3f3 100644 --- a/LibOVR/Src/Net/OVR_PacketizedTCPSocket.h +++ b/LibOVR/Src/Net/OVR_PacketizedTCPSocket.h @@ -29,8 +29,8 @@ limitations under the License. #define OVR_PacketizedTCPSocket_h #include "OVR_Socket.h" -#include "../Kernel/OVR_Allocator.h" -#include "../Kernel/OVR_Atomic.h" +#include "Kernel/OVR_Allocator.h" +#include "Kernel/OVR_Atomic.h" #ifdef OVR_OS_WIN32 #include "OVR_Win32_Socket.h" diff --git a/LibOVR/Src/Net/OVR_RPC1.cpp b/LibOVR/Src/Net/OVR_RPC1.cpp index 12afb09..b6137a5 100644 --- a/LibOVR/Src/Net/OVR_RPC1.cpp +++ b/LibOVR/Src/Net/OVR_RPC1.cpp @@ -26,7 +26,7 @@ limitations under the License. #include "OVR_RPC1.h" #include "OVR_BitStream.h" -#include "../Kernel/OVR_Threads.h" // Thread::MSleep +#include "Kernel/OVR_Threads.h" // Thread::MSleep #include "OVR_MessageIDTypes.h" namespace OVR { namespace Net { namespace Plugins { @@ -54,13 +54,12 @@ RPC1::RPC1() RPC1::~RPC1() { - slotHash.Clear(); delete blockingReturnValue; } -void RPC1::RegisterSlot(OVR::String sharedIdentifier, OVR::Observer<RPCSlot>* rpcSlotObserver ) +void RPC1::RegisterSlot(OVR::String sharedIdentifier, OVR::CallbackListener<RPCSlot>* rpcSlotObserver) { - slotHash.AddObserverToSubject(sharedIdentifier, rpcSlotObserver); + slotHash.AddListener(sharedIdentifier, rpcSlotObserver); } bool RPC1::RegisterBlockingFunction(OVR::String uniqueID, RPCDelegate blockingFunction) @@ -235,18 +234,15 @@ void RPC1::OnReceive(ReceivePayload *pPayload, ListenerReceiveResult *lrrOut) OVR::String sharedIdentifier; bsIn.Read(sharedIdentifier); - Observer<RPCSlot> *o = slotHash.GetSubject(sharedIdentifier); + CallbackEmitter<RPCSlot>* o = slotHash.GetKey(sharedIdentifier); if (o) { bsIn.AlignReadToByteBoundary(); - if (o) - { - OVR::Net::BitStream serializedParameters(bsIn.GetData() + bsIn.GetReadOffset()/8, bsIn.GetNumberOfUnreadBits()/8, false); + OVR::Net::BitStream serializedParameters(bsIn.GetData() + bsIn.GetReadOffset()/8, bsIn.GetNumberOfUnreadBits()/8, false); - o->Call(&serializedParameters, pPayload); - } + o->Call(&serializedParameters, pPayload); } } } diff --git a/LibOVR/Src/Net/OVR_RPC1.h b/LibOVR/Src/Net/OVR_RPC1.h index 6104ccf..6af2155 100644 --- a/LibOVR/Src/Net/OVR_RPC1.h +++ b/LibOVR/Src/Net/OVR_RPC1.h @@ -29,12 +29,12 @@ limitations under the License. #define OVR_Net_RPC_h #include "OVR_NetworkPlugin.h" -#include "../Kernel/OVR_Hash.h" -#include "../Kernel/OVR_String.h" +#include "Kernel/OVR_Hash.h" +#include "Kernel/OVR_String.h" #include "OVR_BitStream.h" -#include "../Kernel/OVR_Threads.h" -#include "../Kernel/OVR_Delegates.h" -#include "../Kernel//OVR_Observer.h" +#include "Kernel/OVR_Threads.h" +#include "Kernel/OVR_Delegates.h" +#include "Kernel/OVR_Callbacks.h" namespace OVR { namespace Net { namespace Plugins { @@ -55,7 +55,7 @@ public: /// \param[in] sharedIdentifier A string to identify the slot. Recommended to be the same as the name of the function. /// \param[in] functionPtr Pointer to the function. /// \param[in] callPriority Slots are called by order of the highest callPriority first. For slots with the same priority, they are called in the order they are registered - void RegisterSlot(OVR::String sharedIdentifier, OVR::Observer<RPCSlot> *rpcSlotObserver); + void RegisterSlot(OVR::String sharedIdentifier, CallbackListener<RPCSlot>* rpcSlotListener); /// \brief Same as \a RegisterFunction, but is called with CallBlocking() instead of Call() and returns a value to the caller bool RegisterBlockingFunction(OVR::String uniqueID, RPCDelegate blockingFunction); @@ -88,7 +88,8 @@ protected: virtual void OnConnected(Connection* conn); Hash< String, RPCDelegate, String::HashFunctor > registeredBlockingFunctions; - ObserverHash< RPCSlot > slotHash; + + CallbackHash< RPCSlot > slotHash; // Synchronization for RPC caller Lock singleRPCLock; diff --git a/LibOVR/Src/Net/OVR_Session.cpp b/LibOVR/Src/Net/OVR_Session.cpp index 508f0c9..4049c6c 100644 --- a/LibOVR/Src/Net/OVR_Session.cpp +++ b/LibOVR/Src/Net/OVR_Session.cpp @@ -26,36 +26,95 @@ limitations under the License. #include "OVR_Session.h" #include "OVR_PacketizedTCPSocket.h" -#include "../Kernel/OVR_Log.h" -#include "../Service/Service_NetSessionCommon.h" +#include "Kernel/OVR_Log.h" +#include "Service/Service_NetSessionCommon.h" namespace OVR { namespace Net { +// The SDK version requested by the user. +SDKVersion RuntimeSDKVersion; + + //----------------------------------------------------------------------------- // Protocol -static const char* OfficialHelloString = "OculusVR_Hello"; +static const char* OfficialHelloString = "OculusVR_Hello"; static const char* OfficialAuthorizedString = "OculusVR_Authorized"; -void RPC_C2S_Hello::Generate(Net::BitStream* bs) +bool RPC_C2S_Hello::Serialize(bool writeToBitstream, Net::BitStream* bs) +{ + bs->Serialize(writeToBitstream, HelloString); + bs->Serialize(writeToBitstream, MajorVersion); + bs->Serialize(writeToBitstream, MinorVersion); + if (!bs->Serialize(writeToBitstream, PatchVersion)) + return false; + + // If an older client is connecting to us, + if (!writeToBitstream && (MajorVersion * 100) + (MinorVersion * 10) + PatchVersion < 121) + { + // The following was version code was added to RPC version 1.2 + // without bumping it up to 1.3 and introducing an incompatibility. + // We can do this because an older server will not read this additional data. + return true; + } + + bs->Serialize(writeToBitstream, CodeVersion.ProductVersion); + bs->Serialize(writeToBitstream, CodeVersion.MajorVersion); + bs->Serialize(writeToBitstream, CodeVersion.MinorVersion); + bs->Serialize(writeToBitstream, CodeVersion.RequestedMinorVersion); + bs->Serialize(writeToBitstream, CodeVersion.PatchVersion); + bs->Serialize(writeToBitstream, CodeVersion.BuildNumber); + return bs->Serialize(writeToBitstream, CodeVersion.FeatureVersion); +} + +void RPC_C2S_Hello::ClientGenerate(Net::BitStream* bs) { RPC_C2S_Hello hello; - hello.HelloString = OfficialHelloString; + hello.HelloString = OfficialHelloString; hello.MajorVersion = RPCVersion_Major; hello.MinorVersion = RPCVersion_Minor; hello.PatchVersion = RPCVersion_Patch; - hello.Serialize(bs); + OVR_ASSERT(OVR::Net::RuntimeSDKVersion.ProductVersion != UINT16_MAX); + hello.CodeVersion = OVR::Net::RuntimeSDKVersion; // This should have been set to a value earlier in the first steps of ovr initialization. + hello.Serialize(true, bs); } -bool RPC_C2S_Hello::Validate() +bool RPC_C2S_Hello::ServerValidate() { + // Server checks the protocol version return MajorVersion == RPCVersion_Major && MinorVersion <= RPCVersion_Minor && HelloString.CompareNoCase(OfficialHelloString) == 0; } -void RPC_S2C_Authorization::Generate(Net::BitStream* bs, String errorString) +bool RPC_S2C_Authorization::Serialize(bool writeToBitstream, Net::BitStream* bs) +{ + bs->Serialize(writeToBitstream, AuthString); + bs->Serialize(writeToBitstream, MajorVersion); + bs->Serialize(writeToBitstream, MinorVersion); + if (!bs->Serialize(writeToBitstream, PatchVersion)) + return false; + + // If an older client is connecting to us, + if (!writeToBitstream && (MajorVersion * 100) + (MinorVersion * 10) + PatchVersion < 121) + { + // The following was version code was added to RPC version 1.2 + // without bumping it up to 1.3 and introducing an incompatibility. + // We can do this because an older server will not read this additional data. + return true; + } + + bs->Serialize(writeToBitstream, CodeVersion.ProductVersion); + bs->Serialize(writeToBitstream, CodeVersion.MajorVersion); + bs->Serialize(writeToBitstream, CodeVersion.MinorVersion); + bs->Serialize(writeToBitstream, CodeVersion.RequestedMinorVersion); + bs->Serialize(writeToBitstream, CodeVersion.PatchVersion); + bs->Serialize(writeToBitstream, CodeVersion.BuildNumber); + return bs->Serialize(writeToBitstream, CodeVersion.FeatureVersion); +} + +void RPC_S2C_Authorization::ServerGenerate(Net::BitStream* bs, String errorString) { RPC_S2C_Authorization auth; if (errorString.IsEmpty()) @@ -69,16 +128,33 @@ void RPC_S2C_Authorization::Generate(Net::BitStream* bs, String errorString) auth.MajorVersion = RPCVersion_Major; auth.MinorVersion = RPCVersion_Minor; auth.PatchVersion = RPCVersion_Patch; - auth.Serialize(bs); + // Leave CurrentSDKVersion as it is. + auth.Serialize(true, bs); } -bool RPC_S2C_Authorization::Validate() +bool RPC_S2C_Authorization::ClientValidate() { return AuthString.CompareNoCase(OfficialAuthorizedString) == 0; } //----------------------------------------------------------------------------- +// SingleProcess + +static bool SingleProcess = false; + +void Session::SetSingleProcess(bool enable) +{ + SingleProcess = enable; +} + +bool Session::IsSingleProcess() +{ + return SingleProcess; +} + + +//----------------------------------------------------------------------------- // Session void Session::Shutdown() @@ -111,29 +187,25 @@ void Session::Shutdown() SessionResult Session::Listen(ListenerDescription* pListenerDescription) { - if (pListenerDescription->Transport == TransportType_PacketizedTCP) - { - BerkleyListenerDescription* bld = (BerkleyListenerDescription*)pListenerDescription; - TCPSocket* tcpSocket = (TCPSocket*)bld->BoundSocketToListenWith.GetPtr(); + if (pListenerDescription->Transport == TransportType_PacketizedTCP) + { + BerkleyListenerDescription* bld = (BerkleyListenerDescription*)pListenerDescription; + TCPSocket* tcpSocket = (TCPSocket*)bld->BoundSocketToListenWith.GetPtr(); if (tcpSocket->Listen() < 0) { return SessionResult_ListenFailure; } - Lock::Locker locker(&SocketListenersLock); + Lock::Locker locker(&SocketListenersLock); SocketListeners.PushBack(tcpSocket); - } - else if (pListenerDescription->Transport == TransportType_Loopback) - { - HasLoopbackListener = true; - } + } else { OVR_ASSERT(false); } - return SessionResult_OK; + return SessionResult_OK; } SessionResult Session::Connect(ConnectParameters *cp) @@ -153,6 +225,28 @@ SessionResult Session::Connect(ConnectParameters *cp) return SessionResult_AlreadyConnected; } + // If we are already connected, don't create a duplicate connection + if (FullConnections.GetSizeI() > 0) + { + return SessionResult_AlreadyConnected; + } + + // If we are already connecting, don't create a duplicate connection + const int count = AllConnections.GetSizeI(); + for (int i = 0; i < count; ++i) + { + Connection* arrayItem = AllConnections[i].GetPtr(); + + OVR_ASSERT(arrayItem); + if (arrayItem) { + if (arrayItem->State == Client_ConnectedWait + || arrayItem->State == Client_Connecting) + { + return SessionResult_ConnectInProgress; + } + } + } + TCPSocketBase* tcpSock = (TCPSocketBase*)cp2->BoundSocketToConnectWith.GetPtr(); int ret = tcpSock->Connect(&cp2->RemoteAddress); @@ -174,7 +268,6 @@ SessionResult Session::Connect(ConnectParameters *cp) c->SetState(Client_Connecting); AllConnections.PushBack(c); - } if (cp2->Blocking) @@ -182,11 +275,12 @@ SessionResult Session::Connect(ConnectParameters *cp) c->WaitOnConnecting(); } - if (c->State == State_Connected) + EConnectionState state = c->State; + if (state == State_Connected) { return SessionResult_OK; } - else if (c->State == Client_Connecting) + else if (state == Client_Connecting) { return SessionResult_ConnectInProgress; } @@ -195,49 +289,33 @@ SessionResult Session::Connect(ConnectParameters *cp) return SessionResult_ConnectFailure; } } - else if (cp->Transport == TransportType_Loopback) - { - if (HasLoopbackListener) - { - Ptr<Connection> c = AllocConnection(cp->Transport); - if (!c) - { - return SessionResult_ConnectFailure; - } - - c->Transport = cp->Transport; - c->SetState(State_Connected); - - { - Lock::Locker locker(&ConnectionsLock); - AllConnections.PushBack(c); - } - - invokeSessionEvent(&SessionListener::OnConnectionRequestAccepted, c); - } - else - { - OVR_ASSERT(false); - } - } else { OVR_ASSERT(false); } - return SessionResult_OK; + return SessionResult_OK; } +static Session* SingleProcessServer = nullptr; + SessionResult Session::ListenPTCP(OVR::Net::BerkleyBindParameters *bbp) { - Ptr<PacketizedTCPSocket> listenSocket = *new OVR::Net::PacketizedTCPSocket(); + if (Session::IsSingleProcess()) + { + // Do not actually listen on a socket. + SingleProcessServer = this; + return SessionResult_OK; + } + + Ptr<PacketizedTCPSocket> listenSocket = *new OVR::Net::PacketizedTCPSocket(); if (listenSocket->Bind(bbp) == INVALID_SOCKET) { return SessionResult_BindFailure; } - BerkleyListenerDescription bld; - bld.BoundSocketToListenWith = listenSocket.GetPtr(); + BerkleyListenerDescription bld; + bld.BoundSocketToListenWith = listenSocket.GetPtr(); bld.Transport = TransportType_PacketizedTCP; return Listen(&bld); @@ -245,16 +323,46 @@ SessionResult Session::ListenPTCP(OVR::Net::BerkleyBindParameters *bbp) SessionResult Session::ConnectPTCP(OVR::Net::BerkleyBindParameters* bbp, SockAddr* remoteAddress, bool blocking) { + if (Session::IsSingleProcess()) + { + OVR_ASSERT(SingleProcessServer); // ListenPTCP() must be called before ConnectPTCP() + + SingleProcessServer->SingleTargetSession = this; + SingleTargetSession = SingleProcessServer; + + Ptr<PacketizedTCPSocket> s = *new PacketizedTCPSocket; + SockAddr sa; + sa.Set("::1", 10101, SOCK_STREAM); + + Ptr<Connection> newConnection = AllocConnection(TransportType_PacketizedTCP); + if (!newConnection) + { + return SessionResult_ConnectFailure; + } + + PacketizedTCPConnection* c = (PacketizedTCPConnection*)newConnection.GetPtr(); + c->pSocket = s; + c->Address = &sa; + c->Transport = TransportType_PacketizedTCP; + c->SetState(Client_Connecting); + AllConnections.PushBack(c); + + SingleTargetSession->TCP_OnAccept(s, &sa, INVALID_SOCKET); + TCP_OnConnected(s); + + return SessionResult_OK; + } + ConnectParametersBerkleySocket cp(NULL, remoteAddress, blocking, TransportType_PacketizedTCP); Ptr<PacketizedTCPSocket> connectSocket = *new PacketizedTCPSocket(); - cp.BoundSocketToConnectWith = connectSocket.GetPtr(); + cp.BoundSocketToConnectWith = connectSocket.GetPtr(); if (connectSocket->Bind(bbp) == INVALID_SOCKET) { return SessionResult_BindFailure; } - return Connect(&cp); + return Connect(&cp); } Ptr<PacketizedTCPConnection> Session::findConnectionBySockAddr(SockAddr* address) @@ -280,47 +388,26 @@ Ptr<PacketizedTCPConnection> Session::findConnectionBySockAddr(SockAddr* address int Session::Send(SendParameters *payload) { - if (payload->pConnection->Transport == TransportType_Loopback) - { - Lock::Locker locker(&SessionListenersLock); - - const int count = SessionListeners.GetSizeI(); - for (int i = 0; i < count; ++i) - { - SessionListener* sl = SessionListeners[i]; - - // FIXME: This looks like it needs to be reviewed at some point.. - ReceivePayload rp; - rp.Bytes = payload->Bytes; - rp.pConnection = payload->pConnection; - rp.pData = (uint8_t*)payload->pData; // FIXME - ListenerReceiveResult lrr = LRR_CONTINUE; - sl->OnReceive(&rp, &lrr); - if (lrr == LRR_RETURN) - { - return payload->Bytes; - } - else if (lrr == LRR_BREAK) - { - break; - } - } - - return payload->Bytes; - } - else if (payload->pConnection->Transport == TransportType_PacketizedTCP) - { - PacketizedTCPConnection* conn = (PacketizedTCPConnection*)payload->pConnection.GetPtr(); - - return conn->pSocket->Send(payload->pData, payload->Bytes); - } - else + if (payload->pConnection->Transport == TransportType_PacketizedTCP) { - OVR_ASSERT(false); + if (Session::IsSingleProcess()) + { + OVR_ASSERT(SingleTargetSession->AllConnections.GetSizeI() > 0); + PacketizedTCPConnection* conn = (PacketizedTCPConnection*)SingleTargetSession->AllConnections[0].GetPtr(); + SingleTargetSession->TCP_OnRecv(conn->pSocket, (uint8_t*)payload->pData, payload->Bytes); + return payload->Bytes; + } + else + { + PacketizedTCPConnection* conn = (PacketizedTCPConnection*)payload->pConnection.GetPtr(); + return conn->pSocket->Send(payload->pData, payload->Bytes); + } } + OVR_ASSERT(false); // Should not reach here return 0; } + void Session::Broadcast(BroadcastParameters *payload) { SendParameters sp; @@ -338,21 +425,29 @@ void Session::Broadcast(BroadcastParameters *payload) } } } -// DO NOT CALL Poll() FROM MULTIPLE THREADS due to allBlockingTcpSockets being a member + +// DO NOT CALL Poll() FROM MULTIPLE THREADS due to AllBlockingTcpSockets being a member void Session::Poll(bool listeners) { - allBlockingTcpSockets.Clear(); + if (Net::Session::IsSingleProcess()) + { + // Spend a lot of time sleeping in single process mode + Thread::MSleep(100); + return; + } - if (listeners) - { - Lock::Locker locker(&SocketListenersLock); + AllBlockingTcpSockets.Clear(); + + if (listeners) + { + Lock::Locker locker(&SocketListenersLock); const int listenerCount = SocketListeners.GetSizeI(); for (int i = 0; i < listenerCount; ++i) - { - allBlockingTcpSockets.PushBack(SocketListeners[i]); - } - } + { + AllBlockingTcpSockets.PushBack(SocketListeners[i]); + } + } { Lock::Locker locker(&ConnectionsLock); @@ -366,7 +461,7 @@ void Session::Poll(bool listeners) { PacketizedTCPConnection* ptcp = (PacketizedTCPConnection*)arrayItem; - allBlockingTcpSockets.PushBack(ptcp->pSocket); + AllBlockingTcpSockets.PushBack(ptcp->pSocket); } else { @@ -375,15 +470,15 @@ void Session::Poll(bool listeners) } } - const int count = allBlockingTcpSockets.GetSizeI(); - if (count > 0) - { + const int count = AllBlockingTcpSockets.GetSizeI(); + if (count > 0) + { TCPSocketPollState state; // Add all the sockets for polling, for (int i = 0; i < count; ++i) { - Net::TCPSocket* sock = allBlockingTcpSockets[i].GetPtr(); + Net::TCPSocket* sock = AllBlockingTcpSockets[i].GetPtr(); // If socket handle is invalid, if (sock->GetSocketHandle() == INVALID_SOCKET) @@ -399,20 +494,20 @@ void Session::Poll(bool listeners) } // If polling returns with an event, - if (state.Poll(allBlockingTcpSockets[0]->GetBlockingTimeoutUsec(), allBlockingTcpSockets[0]->GetBlockingTimeoutSec())) + if (state.Poll(AllBlockingTcpSockets[0]->GetBlockingTimeoutUsec(), AllBlockingTcpSockets[0]->GetBlockingTimeoutSec())) { // Handle any events for each socket for (int i = 0; i < count; ++i) { - state.HandleEvent(allBlockingTcpSockets[i], this); + state.HandleEvent(AllBlockingTcpSockets[i], this); } } - } + } } void Session::AddSessionListener(SessionListener* se) { - Lock::Locker locker(&SessionListenersLock); + Lock::Locker locker(&SessionListenersLock); const int count = SessionListeners.GetSizeI(); for (int i = 0; i < count; ++i) @@ -425,36 +520,35 @@ void Session::AddSessionListener(SessionListener* se) } SessionListeners.PushBack(se); - se->OnAddedToSession(this); + se->OnAddedToSession(this); } void Session::RemoveSessionListener(SessionListener* se) { - Lock::Locker locker(&SessionListenersLock); + Lock::Locker locker(&SessionListenersLock); const int count = SessionListeners.GetSizeI(); - for (int i = 0; i < count; ++i) - { + for (int i = 0; i < count; ++i) + { if (SessionListeners[i] == se) - { + { se->OnRemovedFromSession(this); SessionListeners.RemoveAtUnordered(i); break; - } - } + } + } } -SInt32 Session::GetActiveSocketsCount() + +int Session::GetActiveSocketsCount() { - Lock::Locker locker1(&SocketListenersLock); - Lock::Locker locker2(&ConnectionsLock); - return SocketListeners.GetSize() + AllConnections.GetSize()>0; + return SocketListeners.GetSizeI() + AllConnections.GetSizeI(); } + Ptr<Connection> Session::AllocConnection(TransportType transport) { switch (transport) { - case TransportType_Loopback: return *new Connection(); case TransportType_TCP: return *new TCPConnection(); case TransportType_PacketizedTCP: return *new PacketizedTCPConnection(); default: @@ -511,14 +605,14 @@ int Session::invokeSessionListeners(ReceivePayload* rp) void Session::TCP_OnRecv(Socket* pSocket, uint8_t* pData, int bytesRead) { - // KevinJ: 9/2/2014 Fix deadlock - Watchdog calls Broadcast(), which locks ConnectionsLock(). - // Lock::Locker locker(&ConnectionsLock); + // KevinJ: 9/2/2014 Fix deadlock - Watchdog calls Broadcast(), which locks ConnectionsLock(). + // Lock::Locker locker(&ConnectionsLock); // Look for the connection in the full connection list first int connIndex; - ConnectionsLock.DoLock(); + ConnectionsLock.DoLock(); Ptr<PacketizedTCPConnection> conn = findConnectionBySocket(AllConnections, pSocket, &connIndex); - ConnectionsLock.Unlock(); + ConnectionsLock.Unlock(); if (conn) { if (conn->State == State_Connected) @@ -537,8 +631,8 @@ void Session::TCP_OnRecv(Socket* pSocket, uint8_t* pData, int bytesRead) BitStream bsIn((char*)pData, bytesRead, false); RPC_S2C_Authorization auth; - if (!auth.Deserialize(&bsIn) || - !auth.Validate()) + if (!auth.Serialize(false, &bsIn) || + !auth.ClientValidate()) { LogError("{ERR-001} [Session] REJECTED: OVRService did not authorize us: %s", auth.AuthString.ToCStr()); @@ -551,16 +645,18 @@ void Session::TCP_OnRecv(Socket* pSocket, uint8_t* pData, int bytesRead) conn->RemoteMajorVersion = auth.MajorVersion; conn->RemoteMinorVersion = auth.MinorVersion; conn->RemotePatchVersion = auth.PatchVersion; + conn->RemoteCodeVersion = auth.CodeVersion; // Mark as connected conn->SetState(State_Connected); - ConnectionsLock.DoLock(); - int connIndex2; - if (findConnectionBySocket(AllConnections, pSocket, &connIndex2)==conn && findConnectionBySocket(FullConnections, pSocket, &connIndex2)==NULL) - { - FullConnections.PushBack(conn); - } - ConnectionsLock.Unlock(); + ConnectionsLock.DoLock(); + int connIndex2; + if (findConnectionBySocket(AllConnections, pSocket, &connIndex2)==conn && findConnectionBySocket(FullConnections, pSocket, &connIndex2)==NULL) + { + FullConnections.PushBack(conn); + HaveFullConnections.store(true, std::memory_order_relaxed); + } + ConnectionsLock.Unlock(); invokeSessionEvent(&SessionListener::OnConnectionRequestAccepted, conn); } } @@ -570,41 +666,59 @@ void Session::TCP_OnRecv(Socket* pSocket, uint8_t* pData, int bytesRead) BitStream bsIn((char*)pData, bytesRead, false); RPC_C2S_Hello hello; - if (!hello.Deserialize(&bsIn) || - !hello.Validate()) + if (!hello.Serialize(false, &bsIn) || + !hello.ServerValidate()) { - LogError("{ERR-002} [Session] REJECTED: Rift application is using an incompatible version %d.%d.%d (my version=%d.%d.%d)", - hello.MajorVersion, hello.MinorVersion, hello.PatchVersion, - RPCVersion_Major, RPCVersion_Minor, RPCVersion_Patch); + LogError("{ERR-002} [Session] REJECTED: Rift application is using an incompatible version %d.%d.%d, feature version %d (my version=%d.%d.%d, feature version %d)", + hello.MajorVersion, hello.MinorVersion, hello.PatchVersion, hello.CodeVersion.FeatureVersion, + RPCVersion_Major, RPCVersion_Minor, RPCVersion_Patch, OVR_FEATURE_VERSION); conn->SetState(State_Zombie); // Send auth response BitStream bsOut; - RPC_S2C_Authorization::Generate(&bsOut, "Incompatible protocol version. Please make sure your OVRService and SDK are both up to date."); - conn->pSocket->Send(bsOut.GetData(), bsOut.GetNumberOfBytesUsed()); + RPC_S2C_Authorization::ServerGenerate(&bsOut, "Incompatible protocol version. Please make sure your OVRService and SDK are both up to date."); + + SendParameters sp; + sp.Bytes = bsOut.GetNumberOfBytesUsed(); + sp.pData = bsOut.GetData(); + sp.pConnection = conn; + Send(&sp); } else { + if (hello.CodeVersion.FeatureVersion != OVR_FEATURE_VERSION) + { + LogError("[Session] WARNING: Rift application is using a different feature version than the server (server version = %d, app version = %d)", + OVR_FEATURE_VERSION, hello.CodeVersion.FeatureVersion); + } + // Read remote version conn->RemoteMajorVersion = hello.MajorVersion; conn->RemoteMinorVersion = hello.MinorVersion; conn->RemotePatchVersion = hello.PatchVersion; + conn->RemoteCodeVersion = hello.CodeVersion; // Send auth response BitStream bsOut; - RPC_S2C_Authorization::Generate(&bsOut); - conn->pSocket->Send(bsOut.GetData(), bsOut.GetNumberOfBytesUsed()); + RPC_S2C_Authorization::ServerGenerate(&bsOut); + + SendParameters sp; + sp.Bytes = bsOut.GetNumberOfBytesUsed(); + sp.pData = bsOut.GetData(); + sp.pConnection = conn; + Send(&sp); // Mark as connected conn->SetState(State_Connected); - ConnectionsLock.DoLock(); - int connIndex2; - if (findConnectionBySocket(AllConnections, pSocket, &connIndex2)==conn && findConnectionBySocket(FullConnections, pSocket, &connIndex2)==NULL) - { - FullConnections.PushBack(conn); - } - ConnectionsLock.Unlock(); + ConnectionsLock.DoLock(); + int connIndex2; + if (findConnectionBySocket(AllConnections, pSocket, &connIndex2)==conn && findConnectionBySocket(FullConnections, pSocket, &connIndex2)==NULL) + { + FullConnections.PushBack(conn); + HaveFullConnections.store(true, std::memory_order_relaxed); + } + ConnectionsLock.Unlock(); invokeSessionEvent(&SessionListener::OnNewIncomingConnection, conn); } @@ -618,10 +732,10 @@ void Session::TCP_OnRecv(Socket* pSocket, uint8_t* pData, int bytesRead) void Session::TCP_OnClosed(TCPSocket* s) { - Lock::Locker locker(&ConnectionsLock); + Lock::Locker locker(&ConnectionsLock); // If found in the full connection list, - int connIndex; + int connIndex = 0; Ptr<PacketizedTCPConnection> conn = findConnectionBySocket(AllConnections, s, &connIndex); if (conn) { @@ -631,6 +745,10 @@ void Session::TCP_OnClosed(TCPSocket* s) if (findConnectionBySocket(FullConnections, s, &connIndex)) { FullConnections.RemoveAtUnordered(connIndex); + if (FullConnections.GetSizeI() < 1) + { + HaveFullConnections.store(false, std::memory_order_relaxed); + } } // Generate an appropriate event for the current state @@ -659,23 +777,23 @@ void Session::TCP_OnClosed(TCPSocket* s) void Session::TCP_OnAccept(TCPSocket* pListener, SockAddr* pSockAddr, SocketHandle newSock) { OVR_UNUSED(pListener); - OVR_ASSERT(pListener->Transport == TransportType_PacketizedTCP); + Ptr<PacketizedTCPSocket> newSocket = *new PacketizedTCPSocket(newSock, false); + OVR_ASSERT(pListener->Transport == TransportType_PacketizedTCP); - Ptr<PacketizedTCPSocket> newSocket = *new PacketizedTCPSocket(newSock, false); // If pSockAddr is not localhost, then close newSock - if (pSockAddr->IsLocalhost()==false) + if (!pSockAddr->IsLocalhost()) { newSocket->Close(); return; } - if (newSocket) - { - Ptr<Connection> b = AllocConnection(TransportType_PacketizedTCP); - Ptr<PacketizedTCPConnection> c = (PacketizedTCPConnection*)b.GetPtr(); - c->pSocket = newSocket; - c->Address = *pSockAddr; + if (newSocket) + { + Ptr<Connection> b = AllocConnection(TransportType_PacketizedTCP); + Ptr<PacketizedTCPConnection> c = (PacketizedTCPConnection*)b.GetPtr(); + c->pSocket = newSocket; + c->Address = *pSockAddr; c->State = Server_ConnectedWait; { @@ -684,7 +802,7 @@ void Session::TCP_OnAccept(TCPSocket* pListener, SockAddr* pSockAddr, SocketHand } // Server does not send the first packet. It waits for the client to send its version - } + } } void Session::TCP_OnConnected(TCPSocket *s) @@ -697,13 +815,18 @@ void Session::TCP_OnConnected(TCPSocket *s) { OVR_ASSERT(conn->State == Client_Connecting); + // Just update state but do not generate any notifications yet + conn->SetState(Client_ConnectedWait); + // Send hello message BitStream bsOut; - RPC_C2S_Hello::Generate(&bsOut); - conn->pSocket->Send(bsOut.GetData(), bsOut.GetNumberOfBytesUsed()); + RPC_C2S_Hello::ClientGenerate(&bsOut); - // Just update state but do not generate any notifications yet - conn->State = Client_ConnectedWait; + SendParameters sp; + sp.Bytes = bsOut.GetNumberOfBytesUsed(); + sp.pData = bsOut.GetData(); + sp.pConnection = conn; + Send(&sp); } } diff --git a/LibOVR/Src/Net/OVR_Session.h b/LibOVR/Src/Net/OVR_Session.h index b6ef07b..3f329f9 100644..100755 --- a/LibOVR/Src/Net/OVR_Session.h +++ b/LibOVR/Src/Net/OVR_Session.h @@ -28,12 +28,16 @@ limitations under the License. #ifndef OVR_Session_h #define OVR_Session_h +#include <atomic> + +#include <OVR_Version.h> #include "OVR_Socket.h" #include "OVR_PacketizedTCPSocket.h" -#include "../Kernel/OVR_Array.h" -#include "../Kernel/OVR_Threads.h" -#include "../Kernel/OVR_Atomic.h" -#include "../Kernel/OVR_RefCount.h" +#include "Kernel/OVR_Array.h" +#include "Kernel/OVR_Threads.h" +#include "Kernel/OVR_RefCount.h" +#include <stdint.h> + namespace OVR { namespace Net { @@ -45,48 +49,82 @@ class Session; // // Please update changelog below: // 1.0.0 - [SDK 0.4.0] Initial version (July 21, 2014) -// 1.1.0 - Add Get/SetDriverMode_1, HMDCountUpdate_1 -// Version mismatch results (July 28, 2014) +// 1.1.0 - [SDK 0.4.1] Add Get/SetDriverMode_1, HMDCountUpdate_1 Version mismatch results (July 28, 2014) +// 1.2.0 - [SDK 0.4.4] +// 1.2.1 - [SDK 0.5.0] Added DyLib model and SDKVersion +// 1.3.0 - [SDK 0.5.0] Multiple shared memory regions for different objects //----------------------------------------------------------------------------- static const uint16_t RPCVersion_Major = 1; // MAJOR version when you make incompatible API changes, -static const uint16_t RPCVersion_Minor = 2; // MINOR version when you add functionality in a backwards-compatible manner, and +static const uint16_t RPCVersion_Minor = 3; // MINOR version when you add functionality in a backwards-compatible manner, and static const uint16_t RPCVersion_Patch = 0; // PATCH version when you make backwards-compatible bug fixes. +#define OVR_FEATURE_VERSION 0 + + +struct SDKVersion +{ + uint16_t ProductVersion; // CAPI DLL product number, 0 before first consumer release + uint16_t MajorVersion; // CAPI DLL version major number + uint16_t MinorVersion; // CAPI DLL version minor number + uint16_t RequestedMinorVersion; // Number provided by game in ovr_Initialize() arguments + uint16_t PatchVersion; // CAPI DLL version patch number + uint16_t BuildNumber; // Number increments per build + uint16_t FeatureVersion; // CAPI DLL feature version number + + SDKVersion() + { + Reset(); + } + + void Reset() + { + ProductVersion = MajorVersion = MinorVersion = UINT16_MAX; + RequestedMinorVersion = PatchVersion = BuildNumber = UINT16_MAX; + FeatureVersion = UINT16_MAX; + } + + void SetCurrent() + { + ProductVersion = OVR_PRODUCT_VERSION; + MajorVersion = OVR_MAJOR_VERSION; + MinorVersion = OVR_MINOR_VERSION; + RequestedMinorVersion = OVR_MINOR_VERSION; + PatchVersion = OVR_PATCH_VERSION; + BuildNumber = OVR_BUILD_NUMBER; + FeatureVersion = OVR_FEATURE_VERSION; + } +}; + +// This is the version that the OVR_CAPI client passes on to the server. It's a global variable +// because it needs to be initialized in ovr_Initialize but read in the OVR_Session module. +// This variable exists as a global in the server but it has no meaning. +extern SDKVersion RuntimeSDKVersion; + + // Client starts communication by sending its version number. struct RPC_C2S_Hello { RPC_C2S_Hello() : MajorVersion(0), MinorVersion(0), - PatchVersion(0) + PatchVersion(0), + CodeVersion() { + CodeVersion.SetCurrent(); } String HelloString; - // Client version info + // Client protocol version info uint16_t MajorVersion, MinorVersion, PatchVersion; - void Serialize(Net::BitStream* bs) - { - bs->Write(HelloString); - bs->Write(MajorVersion); - bs->Write(MinorVersion); - bs->Write(PatchVersion); - } - - bool Deserialize(Net::BitStream* bs) - { - bs->Read(HelloString); - bs->Read(MajorVersion); - bs->Read(MinorVersion); - return bs->Read(PatchVersion); - } - - static void Generate(Net::BitStream* bs); + // Client runtime code version info + SDKVersion CodeVersion; - bool Validate(); + bool Serialize(bool writeToBitstream, Net::BitStream* bs); + static void ClientGenerate(Net::BitStream* bs); + bool ServerValidate(); }; // Server responds with an authorization accepted message, including the server's version number @@ -95,8 +133,10 @@ struct RPC_S2C_Authorization RPC_S2C_Authorization() : MajorVersion(0), MinorVersion(0), - PatchVersion(0) + PatchVersion(0), + CodeVersion() { + CodeVersion.SetCurrent(); } String AuthString; @@ -104,25 +144,13 @@ struct RPC_S2C_Authorization // Server version info uint16_t MajorVersion, MinorVersion, PatchVersion; - void Serialize(Net::BitStream* bs) - { - bs->Write(AuthString); - bs->Write(MajorVersion); - bs->Write(MinorVersion); - bs->Write(PatchVersion); - } + // The SDK version that the server was built with. + // There's no concept of the server requesting an SDK version like the client does. + SDKVersion CodeVersion; - bool Deserialize(Net::BitStream* bs) - { - bs->Read(AuthString); - bs->Read(MajorVersion); - bs->Read(MinorVersion); - return bs->Read(PatchVersion); - } - - static void Generate(Net::BitStream* bs, String errorString = ""); - - bool Validate(); + bool Serialize(bool writeToBitstream, Net::BitStream* bs); + static void ServerGenerate(Net::BitStream* bs, String errorString = ""); + bool ClientValidate(); }; @@ -130,10 +158,10 @@ struct RPC_S2C_Authorization // Result of a session function enum SessionResult { - SessionResult_OK, - SessionResult_BindFailure, - SessionResult_ListenFailure, - SessionResult_ConnectFailure, + SessionResult_OK, + SessionResult_BindFailure, + SessionResult_ListenFailure, + SessionResult_ConnectFailure, SessionResult_ConnectInProgress, SessionResult_AlreadyConnected, }; @@ -166,10 +194,11 @@ public: State(State_Zombie), RemoteMajorVersion(0), RemoteMinorVersion(0), - RemotePatchVersion(0) + RemotePatchVersion(0), + RemoteCodeVersion() { } - virtual ~Connection() // Allow delete from base + virtual ~Connection() // Allow delete from base { } @@ -180,9 +209,10 @@ public: EConnectionState State; // Version number read from remote host just before connection completes - int RemoteMajorVersion; + int RemoteMajorVersion; // RPC version int RemoteMinorVersion; int RemotePatchVersion; + SDKVersion RemoteCodeVersion; }; @@ -192,42 +222,44 @@ class NetworkConnection : public Connection { protected: NetworkConnection() - { - } + { + } virtual ~NetworkConnection() { } public: - virtual void SetState(EConnectionState s) + // Thread-safe interface to set or wait on a connection state change. + // All modifications of the connection state should go through this function, + // on the client side. + void SetState(EConnectionState s) { + Mutex::Locker locker(&StateMutex); + if (s != State) { - Mutex::Locker locker(&StateMutex); + State = s; - if (s != State) + if (State != Client_Connecting && + State != Client_ConnectedWait) { - State = s; - - if (State != Client_Connecting) - { - ConnectingWait.NotifyAll(); - } + ConnectingWait.NotifyAll(); } } } + // Call this function to wait for the state to change to a connected state. void WaitOnConnecting() { Mutex::Locker locker(&StateMutex); - while (State == Client_Connecting) + while (State == Client_Connecting || State == Client_ConnectedWait) { ConnectingWait.Wait(&StateMutex); } } - SockAddr Address; + SockAddr Address; Mutex StateMutex; WaitCondition ConnectingWait; }; @@ -246,7 +278,7 @@ public: } public: - Ptr<TCPSocket> pSocket; + Ptr<TCPSocket> pSocket; }; @@ -255,7 +287,7 @@ public: class PacketizedTCPConnection : public TCPConnection { public: - PacketizedTCPConnection() + PacketizedTCPConnection() { Transport = TransportType_PacketizedTCP; } @@ -284,16 +316,16 @@ public: class BerkleyListenerDescription : public ListenerDescription { public: - static const int DefaultMaxIncomingConnections = 64; - static const int DefaultMaxConnections = 128; + static const int DefaultMaxIncomingConnections = 64; + static const int DefaultMaxConnections = 128; - BerkleyListenerDescription() : - MaxIncomingConnections(DefaultMaxIncomingConnections), - MaxConnections(DefaultMaxConnections) - { - } + BerkleyListenerDescription() : + MaxIncomingConnections(DefaultMaxIncomingConnections), + MaxConnections(DefaultMaxConnections) + { + } - Ptr<BerkleySocket> BoundSocketToListenWith; + Ptr<BerkleySocket> BoundSocketToListenWith; int MaxIncomingConnections; int MaxConnections; }; @@ -303,9 +335,9 @@ public: // Receive payload struct ReceivePayload { - Connection* pConnection; // Source connection - uint8_t* pData; // Pointer to data received - int Bytes; // Number of bytes of data received + Connection* pConnection; // Source connection + uint8_t* pData; // Pointer to data received + int Bytes; // Number of bytes of data received }; //----------------------------------------------------------------------------- @@ -335,22 +367,22 @@ public: class SendParameters { public: - SendParameters() : - pData(NULL), - Bytes(0) - { - } - SendParameters(Ptr<Connection> _pConnection, const void* _pData, int _bytes) : - pConnection(_pConnection), - pData(_pData), - Bytes(_bytes) - { - } + SendParameters() : + pData(NULL), + Bytes(0) + { + } + SendParameters(Ptr<Connection> _pConnection, const void* _pData, int _bytes) : + pConnection(_pConnection), + pData(_pData), + Bytes(_bytes) + { + } public: - Ptr<Connection> pConnection; // Connection to use - const void* pData; // Pointer to data to send - int Bytes; // Number of bytes of data received + Ptr<Connection> pConnection; // Connection to use + const void* pData; // Pointer to data to send + int Bytes; // Number of bytes of data received }; @@ -359,18 +391,18 @@ public: struct ConnectParameters { public: - ConnectParameters() : - Transport(TransportType_None) - { - } + ConnectParameters() : + Transport(TransportType_None) + { + } - TransportType Transport; + TransportType Transport; }; struct ConnectParametersBerkleySocket : public ConnectParameters { - SockAddr RemoteAddress; - Ptr<BerkleySocket> BoundSocketToConnectWith; + SockAddr RemoteAddress; + Ptr<BerkleySocket> BoundSocketToConnectWith; bool Blocking; ConnectParametersBerkleySocket(BerkleySocket* s, SockAddr* addr, bool blocking, @@ -388,11 +420,11 @@ struct ConnectParametersBerkleySocket : public ConnectParameters // Listener receive result enum ListenerReceiveResult { - /// The SessionListener used this message and it shouldn't be given to the user. - LRR_RETURN = 0, + /// The SessionListener used this message and it shouldn't be given to the user. + LRR_RETURN = 0, - /// The SessionListener is going to hold on to this message. Do not deallocate it but do not pass it to other plugins either. - LRR_BREAK, + /// The SessionListener is going to hold on to this message. Do not deallocate it but do not pass it to other plugins either. + LRR_BREAK, /// This message will be processed by other SessionListeners, and at last by the user. LRR_CONTINUE, @@ -406,15 +438,15 @@ enum ListenerReceiveResult class SessionListener { public: - virtual ~SessionListener(){} + virtual ~SessionListener(){} - // Data events + // Data events virtual void OnReceive(ReceivePayload* pPayload, ListenerReceiveResult* lrrOut) { OVR_UNUSED2(pPayload, lrrOut); } - // Connection was closed + // Connection was closed virtual void OnDisconnected(Connection* conn) = 0; - // Connection was created (some data was exchanged to verify protocol compatibility too) + // Connection was created (some data was exchanged to verify protocol compatibility too) virtual void OnConnected(Connection* conn) = 0; // Server accepted client @@ -430,7 +462,7 @@ public: // Disconnected during initial handshake virtual void OnHandshakeAttemptFailed(Connection* conn) { OnConnectionAttemptFailed(conn); } - // Other + // Other virtual void OnAddedToSession(Session* session) { OVR_UNUSED(session); } virtual void OnRemovedFromSession(Session* session) { OVR_UNUSED(session); } }; @@ -442,32 +474,27 @@ public: // Interface for network events such as listening on a socket, sending data, connecting, and disconnecting. Works independently of the transport medium and also implements loopback class Session : public SocketEvent_TCP, public NewOverrideBase { - // Implement a policy to avoid releasing memory backing allBlockingTcpSockets - struct ArrayNoShrinkPolicy : ArrayDefaultPolicy - { - bool NeverShrinking() const { return 1; } - }; - public: Session() : - HasLoopbackListener(false) + HaveFullConnections(false) { } virtual ~Session() { // Ensure memory backing the sockets array is released - allBlockingTcpSockets.ClearAndRelease(); + AllBlockingTcpSockets.ClearAndRelease(); } - virtual SessionResult Listen(ListenerDescription* pListenerDescription); - virtual SessionResult Connect(ConnectParameters* cp); - virtual int Send(SendParameters* payload); + virtual SessionResult Listen(ListenerDescription* pListenerDescription); + virtual SessionResult Connect(ConnectParameters* cp); + virtual int Send(SendParameters* payload); virtual void Broadcast(BroadcastParameters* payload); - // DO NOT CALL Poll() FROM MULTIPLE THREADS due to allBlockingTcpSockets being a member + // DO NOT CALL Poll() FROM MULTIPLE THREADS due to AllBlockingTcpSockets being a member virtual void Poll(bool listeners = true); - virtual void AddSessionListener(SessionListener* se); - virtual void RemoveSessionListener(SessionListener* se); - virtual SInt32 GetActiveSocketsCount(); + virtual void AddSessionListener(SessionListener* se); + virtual void RemoveSessionListener(SessionListener* se); + // GetActiveSocketsCount() is not thread-safe: Socket count may change at any time. + virtual int GetActiveSocketsCount(); // Packetized TCP convenience functions virtual SessionResult ListenPTCP(BerkleyBindParameters* bbp); @@ -476,7 +503,15 @@ public: // Closes all the sockets; useful for interrupting the socket polling during shutdown void Shutdown(); + // Returns true if there is at least one successful connection + // WARNING: This function may not be in sync across threads, but it IS atomic + bool ConnectionSuccessful() const + { + return HaveFullConnections.load(std::memory_order_relaxed); + } + // Get count of successful connections (past handshake point) + // WARNING: This function is not thread-safe int GetConnectionCount() const { return FullConnections.GetSizeI(); @@ -484,15 +519,16 @@ public: Ptr<Connection> GetConnectionAtIndex(int index); protected: - virtual Ptr<Connection> AllocConnection(TransportType transportType); + virtual Ptr<Connection> AllocConnection(TransportType transportType); Lock SocketListenersLock, ConnectionsLock, SessionListenersLock; - bool HasLoopbackListener; // Has loopback listener installed? - Array< Ptr<TCPSocket> > SocketListeners; // List of active sockets + Array< Ptr<TCPSocket> > SocketListeners; // List of active sockets Array< Ptr<Connection> > AllConnections; // List of active connections stuck at the versioning handshake Array< Ptr<Connection> > FullConnections; // List of active connections past the versioning handshake Array< SessionListener* > SessionListeners; // List of session listeners - Array< Ptr< Net::TCPSocket >, ArrayNoShrinkPolicy > allBlockingTcpSockets; // Preallocated blocking sockets array + Array< Ptr< Net::TCPSocket > > AllBlockingTcpSockets; // Preallocated blocking sockets array + + std::atomic<bool> HaveFullConnections; // Tools Ptr<PacketizedTCPConnection> findConnectionBySocket(Array< Ptr<Connection> >& connectionArray, Socket* s, int *connectionIndex = NULL); // Call with ConnectionsLock held @@ -500,11 +536,18 @@ protected: int invokeSessionListeners(ReceivePayload*); void invokeSessionEvent(void(SessionListener::*f)(Connection*), Connection* pConnection); - // TCP - virtual void TCP_OnRecv(Socket* pSocket, uint8_t* pData, int bytesRead); - virtual void TCP_OnClosed(TCPSocket* pSocket); - virtual void TCP_OnAccept(TCPSocket* pListener, SockAddr* pSockAddr, SocketHandle newSock); - virtual void TCP_OnConnected(TCPSocket* pSocket); + // TCP + virtual void TCP_OnRecv(Socket* pSocket, uint8_t* pData, int bytesRead); + virtual void TCP_OnClosed(TCPSocket* pSocket); + virtual void TCP_OnAccept(TCPSocket* pListener, SockAddr* pSockAddr, SocketHandle newSock); + virtual void TCP_OnConnected(TCPSocket* pSocket); + +public: + static void SetSingleProcess(bool enable); + static bool IsSingleProcess(); + +protected: + Session* SingleTargetSession; // Target for SingleProcess mode }; diff --git a/LibOVR/Src/Net/OVR_Socket.h b/LibOVR/Src/Net/OVR_Socket.h index b02e038..df6407f 100644 --- a/LibOVR/Src/Net/OVR_Socket.h +++ b/LibOVR/Src/Net/OVR_Socket.h @@ -28,18 +28,17 @@ limitations under the License. #ifndef OVR_Socket_h #define OVR_Socket_h -#include "../Kernel/OVR_Types.h" -#include "../Kernel/OVR_Timer.h" -#include "../Kernel/OVR_Allocator.h" -#include "../Kernel/OVR_RefCount.h" -#include "../Kernel/OVR_String.h" +#include "Kernel/OVR_Types.h" +#include "Kernel/OVR_Timer.h" +#include "Kernel/OVR_Allocator.h" +#include "Kernel/OVR_RefCount.h" +#include "Kernel/OVR_String.h" // OS-specific socket headers #if defined(OVR_OS_WIN32) #include <WinSock2.h> #include <WS2tcpip.h> -#define WIN32_LEAN_AND_MEAN -#include <windows.h> +#include "Kernel/OVR_Win32_IncludeWindows.h" #else # include <unistd.h> # include <sys/types.h> @@ -72,7 +71,6 @@ static const int SOCKET_ERROR = -1; enum TransportType { TransportType_None, // No transport (useful placeholder for invalid states) - TransportType_Loopback, // Loopback transport: Class talks to itself TransportType_TCP, // TCP/IPv4/v6 TransportType_UDP, // UDP/IPv4/v6 TransportType_PacketizedTCP // Packetized TCP: Message framing is automatic diff --git a/LibOVR/Src/Net/OVR_Unix_Socket.cpp b/LibOVR/Src/Net/OVR_Unix_Socket.cpp index 6f2a678..7477ee7 100644 --- a/LibOVR/Src/Net/OVR_Unix_Socket.cpp +++ b/LibOVR/Src/Net/OVR_Unix_Socket.cpp @@ -25,10 +25,10 @@ limitations under the License. ************************************************************************************/ #include "OVR_Unix_Socket.h" -#include "../Kernel/OVR_Std.h" -#include "../Kernel/OVR_Allocator.h" -#include "../Kernel/OVR_Threads.h" // Thread::MSleep -#include "../Kernel/OVR_Log.h" +#include "Kernel/OVR_Std.h" +#include "Kernel/OVR_Allocator.h" +#include "Kernel/OVR_Threads.h" // Thread::MSleep +#include "Kernel/OVR_Log.h" #include <errno.h> @@ -409,10 +409,14 @@ TCPSocket::TCPSocket(SocketHandle boundHandle, bool isListenSocket) TheSocket = boundHandle; IsListenSocket = isListenSocket; IsConnecting = false; - SetSocketOptions(TheSocket); - // The actual socket is always non-blocking - _Ioctlsocket(TheSocket, 1); + if (TheSocket != INVALID_SOCKET) + { + SetSocketOptions(TheSocket); + + // The actual socket is always non-blocking + _Ioctlsocket(TheSocket, 1); + } } TCPSocket::~TCPSocket() diff --git a/LibOVR/Src/Net/OVR_Win32_Socket.cpp b/LibOVR/Src/Net/OVR_Win32_Socket.cpp index edc6ade..3cd2ada 100644 --- a/LibOVR/Src/Net/OVR_Win32_Socket.cpp +++ b/LibOVR/Src/Net/OVR_Win32_Socket.cpp @@ -1,608 +1,618 @@ -/************************************************************************************
-
-Filename : OVR_Win32_Socket.cpp
-Content : Windows-specific socket-based networking implementation
-Created : June 10, 2014
-Authors : Kevin Jenkins
-
-Copyright : Copyright 2014 Oculus VR, LLC All Rights reserved.
-
-Licensed under the Oculus VR Rift SDK License Version 3.2 (the "License");
-you may not use the Oculus VR Rift SDK except in compliance with the License,
-which is provided at the time of installation or download, or which
-otherwise accompanies this software in either electronic or hard copy form.
-
-You may obtain a copy of the License at
-
-http://www.oculusvr.com/licenses/LICENSE-3.2
-
-Unless required by applicable law or agreed to in writing, the Oculus VR SDK
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-
-************************************************************************************/
-
-#include "OVR_Win32_Socket.h"
-#include "../Kernel/OVR_Std.h"
-#include "../Kernel/OVR_Allocator.h"
-#include "../Kernel/OVR_Threads.h" // Thread::MSleep
-#include "../Kernel/OVR_Log.h"
-
-#include <Winsock2.h>
-
-namespace OVR { namespace Net {
-
-
-//-----------------------------------------------------------------------------
-// WSAStartupSingleton
-
-class WSAStartupSingleton
-{
-public:
- static void AddRef(void);
- static void Deref(void);
-
-protected:
- static int RefCount;
-};
-
-
-// Local data
-int WSAStartupSingleton::RefCount = 0;
-
-
-// Implementation
-void WSAStartupSingleton::AddRef()
-{
- if (++RefCount == 1)
- {
- WSADATA winsockInfo;
- const int errCode = WSAStartup(MAKEWORD(2, 2), &winsockInfo);
- OVR_ASSERT(errCode == 0);
-
- // If an error code is returned
- if (errCode != 0)
- {
- LogError("{ERR-007w} [Socket] Unable to initialize Winsock %d", errCode);
- }
- }
-}
-
-void WSAStartupSingleton::Deref()
-{
- OVR_ASSERT(RefCount > 0);
-
- if (RefCount > 0)
- {
- if (--RefCount == 0)
- {
- WSACleanup();
- RefCount = 0;
- }
- }
-}
-
-
-//-----------------------------------------------------------------------------
-// BerkleySocket
-
-void BerkleySocket::Close()
-{
- if (TheSocket != INVALID_SOCKET)
- {
- closesocket(TheSocket);
- TheSocket = INVALID_SOCKET;
- }
-}
-
-int32_t BerkleySocket::GetSockname(SockAddr *pSockAddrOut)
-{
- struct sockaddr_in6 sa;
- memset(&sa,0,sizeof(sa));
- int size = sizeof(sa);
- int32_t i = getsockname(TheSocket, (sockaddr*) &sa, &size);
- if (i>=0)
- {
- pSockAddrOut->Set(&sa);
- }
- return i;
-}
-
-
-//-----------------------------------------------------------------------------
-// BitStream overloads for SockAddr
-
-BitStream& operator<<(BitStream& out, SockAddr& in)
-{
- out.WriteBits((const unsigned char*) &in.Addr6, sizeof(in.Addr6)*8, true);
- return out;
-}
-
-BitStream& operator>>(BitStream& in, SockAddr& out)
-{
- bool success = in.ReadBits((unsigned char*) &out.Addr6, sizeof(out.Addr6)*8, true);
- OVR_ASSERT(success);
- OVR_UNUSED(success);
- return in;
-}
-
-
-//-----------------------------------------------------------------------------
-// SockAddr
-
-SockAddr::SockAddr()
-{
- WSAStartupSingleton::AddRef();
-
- // Zero out the address to squelch static analysis tools
- ZeroMemory(&Addr6, sizeof(Addr6));
-}
-
-SockAddr::SockAddr(SockAddr* address)
-{
- WSAStartupSingleton::AddRef();
- Set(&address->Addr6);
-}
-
-SockAddr::SockAddr(sockaddr_storage* storage)
-{
- WSAStartupSingleton::AddRef();
- Set(storage);
-}
-
-SockAddr::SockAddr(sockaddr_in6* address)
-{
- WSAStartupSingleton::AddRef();
- Set(address);
-}
-
-SockAddr::SockAddr(const char* hostAddress, uint16_t port, int sockType)
-{
- WSAStartupSingleton::AddRef();
- Set(hostAddress, port, sockType);
-}
-
-void SockAddr::Set(const sockaddr_storage* storage)
-{
- memcpy(&Addr6, storage, sizeof(Addr6));
-}
-
-void SockAddr::Set(const sockaddr_in6* address)
-{
- memcpy(&Addr6, address, sizeof(Addr6));
-}
-
-void SockAddr::Set(const char* hostAddress, uint16_t port, int sockType)
-{
- memset(&Addr6, 0, sizeof(Addr6));
-
- struct addrinfo hints;
-
- // make sure the struct is empty
- memset(&hints, 0, sizeof (addrinfo));
-
- hints.ai_socktype = sockType; // SOCK_DGRAM or SOCK_STREAM
- hints.ai_flags = AI_PASSIVE; // fill in my IP for me
- hints.ai_family = AF_UNSPEC ;
-
- // FIXME See OVR_Unix_Socket implementation and man pages for getaddrinfo.
- // ai_protocol is expecting to be either IPPROTO_UDP and IPPROTO_TCP.
- // But this has been working on windows so I'm leaving it be for
- // now instead of introducing another variable.
- hints.ai_protocol = IPPROTO_IPV6;
-
- struct addrinfo* servinfo = NULL; // will point to the results
-
- char portStr[32];
- OVR_itoa(port, portStr, sizeof(portStr), 10);
- int errcode = getaddrinfo(hostAddress, portStr, &hints, &servinfo);
-
- if (0 != errcode)
- {
- OVR::LogError("{ERR-008w} getaddrinfo error: %s", gai_strerror(errcode));
- }
-
- OVR_ASSERT(servinfo);
-
- if (servinfo)
- {
- memcpy(&Addr6, servinfo->ai_addr, sizeof(Addr6));
-
- freeaddrinfo(servinfo);
- }
-}
-
-uint16_t SockAddr::GetPort()
-{
- return htons(Addr6.sin6_port);
-}
-
-String SockAddr::ToString(bool writePort, char portDelineator) const
-{
- char dest[INET6_ADDRSTRLEN + 1];
-
- int ret = getnameinfo((struct sockaddr*)&Addr6,
- sizeof(struct sockaddr_in6),
- dest,
- INET6_ADDRSTRLEN,
- NULL,
- 0,
- NI_NUMERICHOST);
- if (ret != 0)
- {
- dest[0] = '\0';
- }
-
- if (writePort)
- {
- unsigned char ch[2];
- ch[0]=portDelineator;
- ch[1]=0;
- OVR_strcat(dest, 16, (const char*) ch);
- OVR_itoa(ntohs(Addr6.sin6_port), dest+strlen(dest), 16, 10);
- }
-
- return String(dest);
-}
-bool SockAddr::IsLocalhost() const
-{
- static const unsigned char localhost_bytes[] =
- { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 };
-
- return memcmp(Addr6.sin6_addr.s6_addr, localhost_bytes, 16) == 0;
-}
-bool SockAddr::operator==( const SockAddr& right ) const
-{
- return memcmp(&Addr6, &right.Addr6, sizeof(Addr6)) == 0;
-}
-
-bool SockAddr::operator!=( const SockAddr& right ) const
-{
- return !(*this == right);
-}
-
-bool SockAddr::operator>( const SockAddr& right ) const
-{
- return memcmp(&Addr6, &right.Addr6, sizeof(Addr6)) > 0;
-}
-
-bool SockAddr::operator<( const SockAddr& right ) const
-{
- return memcmp(&Addr6, &right.Addr6, sizeof(Addr6)) < 0;
-}
-
-static bool SetSocketOptions(SocketHandle sock)
-{
- int result = 0;
- int sock_opt;
-
- // This doubles the max throughput rate
- sock_opt = 1024 * 256;
- result |= setsockopt(sock, SOL_SOCKET, SO_RCVBUF, (char *)& sock_opt, sizeof (sock_opt));
-
- // Immediate hard close. Don't linger the socket, or recreating the socket quickly on Vista fails.
- sock_opt = 0;
- result |= setsockopt(sock, SOL_SOCKET, SO_LINGER, (char *)& sock_opt, sizeof (sock_opt));
-
- // This doesn't make much difference: 10% maybe
- sock_opt = 1024 * 16;
- result |= setsockopt(sock, SOL_SOCKET, SO_SNDBUF, (char *)& sock_opt, sizeof (sock_opt));
-
- // If all the setsockopt() returned 0 there were no failures, so return true for success, else false
- return result == 0;
-}
-
-void _Ioctlsocket(SocketHandle sock, unsigned long nonblocking)
-{
- ioctlsocket(sock, FIONBIO, &nonblocking);
-}
-
-static SocketHandle BindShared(int ai_family, int ai_socktype, BerkleyBindParameters* pBindParameters)
-{
- SocketHandle sock;
-
- struct addrinfo hints;
- memset(&hints, 0, sizeof (addrinfo)); // make sure the struct is empty
- hints.ai_family = ai_family;
- hints.ai_socktype = ai_socktype;
- hints.ai_flags = AI_PASSIVE; // fill in my IP for me
- struct addrinfo *servinfo=0, *aip; // will point to the results
- char portStr[32];
- OVR_itoa(pBindParameters->Port, portStr, sizeof(portStr), 10);
-
- int errcode = 0;
- if (!pBindParameters->Address.IsEmpty())
- errcode = getaddrinfo(pBindParameters->Address.ToCStr(), portStr, &hints, &servinfo);
- else
- errcode = getaddrinfo(0, portStr, &hints, &servinfo);
-
- if (0 != errcode)
- {
- OVR::LogError("{ERR-020w} getaddrinfo error: %s", gai_strerror(errcode));
- }
-
- for (aip = servinfo; aip != NULL; aip = aip->ai_next)
- {
- // Open socket. The address type depends on what
- // getaddrinfo() gave us.
- sock = socket(aip->ai_family, aip->ai_socktype, aip->ai_protocol);
- if (sock != INVALID_SOCKET)
- {
- if (bind(sock, aip->ai_addr, (int)aip->ai_addrlen) != SOCKET_ERROR)
- {
- // The actual socket is always non-blocking
- // I control blocking or not using WSAEventSelect
- _Ioctlsocket(sock, 1);
- freeaddrinfo(servinfo);
- return sock;
- }
-
- closesocket(sock);
- }
- }
-
- if (servinfo) { freeaddrinfo(servinfo); }
- return INVALID_SOCKET;
-}
-
-
-//-----------------------------------------------------------------------------
-// UDPSocket
-
-UDPSocket::UDPSocket()
-{
- WSAStartupSingleton::AddRef();
- RecvBuf = new uint8_t[RecvBufSize];
-}
-
-UDPSocket::~UDPSocket()
-{
- WSAStartupSingleton::Deref();
- delete[] RecvBuf;
-}
-
-SocketHandle UDPSocket::Bind(BerkleyBindParameters *pBindParameters)
-{
- SocketHandle s = BindShared(AF_INET6, SOCK_DGRAM, pBindParameters);
- if (s == INVALID_SOCKET)
- return s;
-
- Close();
- TheSocket = s;
- SetSocketOptions(TheSocket);
-
- return TheSocket;
-}
-
-void UDPSocket::OnRecv(SocketEvent_UDP* eventHandler, uint8_t* pData, int bytesRead, SockAddr* address)
-{
- eventHandler->UDP_OnRecv(this, pData, bytesRead, address);
-}
-
-int UDPSocket::Send(const void* pData, int bytes, SockAddr* address)
-{
- return sendto(TheSocket, (const char*)pData, bytes, 0, (const sockaddr*)&address->Addr6, sizeof(address->Addr6));
-}
-
-void UDPSocket::Poll(SocketEvent_UDP *eventHandler)
-{
- struct sockaddr_storage win32_addr;
- socklen_t fromlen;
- int bytesRead;
-
- // FIXME: Implement blocking poll wait for UDP
-
- // While some bytes are read,
- while (fromlen = sizeof(win32_addr), // Must set fromlen each time
- bytesRead = recvfrom(TheSocket, (char*)RecvBuf, RecvBufSize, 0, (sockaddr*)&win32_addr, &fromlen),
- bytesRead > 0)
- {
- SockAddr address(&win32_addr); // Wrap address
-
- OnRecv(eventHandler, RecvBuf, bytesRead, &address);
- }
-}
-
-
-//-----------------------------------------------------------------------------
-// TCPSocket
-
-TCPSocket::TCPSocket()
-{
- IsConnecting = false;
- IsListenSocket = false;
- WSAStartupSingleton::AddRef();
-}
-TCPSocket::TCPSocket(SocketHandle boundHandle, bool isListenSocket)
-{
- TheSocket = boundHandle;
- IsListenSocket = isListenSocket;
- IsConnecting = false;
- WSAStartupSingleton::AddRef();
- SetSocketOptions(TheSocket);
-
- // The actual socket is always non-blocking
- _Ioctlsocket(TheSocket, 1);
-}
-
-TCPSocket::~TCPSocket()
-{
- WSAStartupSingleton::Deref();
-}
-
-void TCPSocket::OnRecv(SocketEvent_TCP* eventHandler, uint8_t* pData, int bytesRead)
-{
- eventHandler->TCP_OnRecv(this, pData, bytesRead);
-}
-
-SocketHandle TCPSocket::Bind(BerkleyBindParameters* pBindParameters)
-{
- SocketHandle s = BindShared(AF_INET6, SOCK_STREAM, pBindParameters);
- if (s == INVALID_SOCKET)
- return s;
-
- Close();
-
- SetBlockingTimeout(pBindParameters->blockingTimeout);
- TheSocket = s;
-
- SetSocketOptions(TheSocket);
-
- return TheSocket;
-}
-
-int TCPSocket::Listen()
-{
- if (IsListenSocket)
- {
- return 0;
- }
-
- int i = listen(TheSocket, SOMAXCONN);
- if (i >= 0)
- {
- IsListenSocket = true;
- }
-
- return i;
-}
-
-int TCPSocket::Connect(SockAddr* address)
-{
- int retval;
-
- retval = connect(TheSocket, (struct sockaddr *) &address->Addr6, sizeof(address->Addr6));
- if (retval < 0)
- {
- DWORD dwIOError = WSAGetLastError();
- if (dwIOError == WSAEWOULDBLOCK)
- {
- IsConnecting = true;
- return 0;
- }
-
- printf( "TCPSocket::Connect failed:Error code - %d\n", dwIOError );
- }
-
- return retval;
-}
-
-int TCPSocket::Send(const void* pData, int bytes)
-{
- if (bytes <= 0)
- {
- return 0;
- }
- else
- {
- return send(TheSocket, (const char*)pData, bytes, 0);
- }
-}
-
-
-//// TCPSocketPollState
-
-TCPSocketPollState::TCPSocketPollState()
-{
- FD_ZERO(&readFD);
- FD_ZERO(&exceptionFD);
- FD_ZERO(&writeFD);
- largestDescriptor = INVALID_SOCKET;
-}
-
-bool TCPSocketPollState::IsValid() const
-{
- return largestDescriptor != INVALID_SOCKET;
-}
-
-void TCPSocketPollState::Add(TCPSocket* tcpSocket)
-{
- if (!tcpSocket)
- {
- return;
- }
-
- SocketHandle handle = tcpSocket->GetSocketHandle();
-
- if (largestDescriptor == INVALID_SOCKET ||
- largestDescriptor < handle)
- {
- largestDescriptor = handle;
- }
-
- FD_SET(handle, &readFD);
- FD_SET(handle, &exceptionFD);
-
- if (tcpSocket->IsConnecting)
- {
- FD_SET(handle, &writeFD);
- }
-}
-
-bool TCPSocketPollState::Poll(long usec, long seconds)
-{
- timeval tv;
- tv.tv_sec = seconds;
- tv.tv_usec = usec;
-
- return (int)select((int)largestDescriptor + 1, &readFD, &writeFD, &exceptionFD, &tv) > 0;
-}
-
-void TCPSocketPollState::HandleEvent(TCPSocket* tcpSocket, SocketEvent_TCP* eventHandler)
-{
- if (!tcpSocket || !eventHandler)
- {
- return;
- }
-
- SocketHandle handle = tcpSocket->GetSocketHandle();
-
- if (tcpSocket->IsConnecting && FD_ISSET(handle, &writeFD))
- {
- tcpSocket->IsConnecting = false;
- eventHandler->TCP_OnConnected(tcpSocket);
- }
-
- if (FD_ISSET(handle, &readFD))
- {
- if (!tcpSocket->IsListenSocket)
- {
- static const int BUFF_SIZE = 8096;
- char data[BUFF_SIZE];
-
- int bytesRead = recv(handle, data, BUFF_SIZE, 0);
- if (bytesRead > 0)
- {
- tcpSocket->OnRecv(eventHandler, (uint8_t*)data, bytesRead);
- }
- else // Disconnection event:
- {
- tcpSocket->IsConnecting = false;
- eventHandler->TCP_OnClosed(tcpSocket);
- }
- }
- else
- {
- struct sockaddr_storage sockAddr;
- socklen_t sockAddrSize = sizeof(sockAddr);
-
- SocketHandle newSock = accept(handle, (sockaddr*)&sockAddr, (socklen_t*)&sockAddrSize);
- if (newSock != INVALID_SOCKET)
- {
- SockAddr sa(&sockAddr);
- eventHandler->TCP_OnAccept(tcpSocket, &sa, newSock);
- }
- }
- }
-
- if (FD_ISSET(handle, &exceptionFD))
- {
- tcpSocket->IsConnecting = false;
- eventHandler->TCP_OnClosed(tcpSocket);
- }
-}
-
-
-}} // namespace OVR::Net
+/************************************************************************************ + +Filename : OVR_Win32_Socket.cpp +Content : Windows-specific socket-based networking implementation +Created : June 10, 2014 +Authors : Kevin Jenkins + +Copyright : Copyright 2014 Oculus VR, LLC All Rights reserved. + +Licensed under the Oculus VR Rift SDK License Version 3.2 (the "License"); +you may not use the Oculus VR Rift SDK except in compliance with the License, +which is provided at the time of installation or download, or which +otherwise accompanies this software in either electronic or hard copy form. + +You may obtain a copy of the License at + +http://www.oculusvr.com/licenses/LICENSE-3.2 + +Unless required by applicable law or agreed to in writing, the Oculus VR SDK +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +************************************************************************************/ + +#include "OVR_Win32_Socket.h" +#include "Kernel/OVR_Std.h" +#include "Kernel/OVR_Allocator.h" +#include "Kernel/OVR_Threads.h" // Thread::MSleep +#include "Kernel/OVR_Log.h" + +#include <Winsock2.h> +#pragma comment(lib, "ws2_32.lib") + +namespace OVR { namespace Net { + + +//----------------------------------------------------------------------------- +// WSAStartupSingleton + +class WSAStartupSingleton +{ +public: + static void AddRef(void); + static void Deref(void); + +protected: + static int RefCount; +}; + + +// Local data +int WSAStartupSingleton::RefCount = 0; + + +// Implementation +void WSAStartupSingleton::AddRef() +{ + if (++RefCount == 1) + { + WSADATA winsockInfo; + const int errCode = WSAStartup(MAKEWORD(2, 2), &winsockInfo); + OVR_ASSERT(errCode == 0); + + // If an error code is returned + if (errCode != 0) + { + LogError("{ERR-007w} [Socket] Unable to initialize Winsock %d", errCode); + } + } +} + +void WSAStartupSingleton::Deref() +{ + OVR_ASSERT(RefCount > 0); + + if (RefCount > 0) + { + if (--RefCount == 0) + { + WSACleanup(); + RefCount = 0; + } + } +} + + +//----------------------------------------------------------------------------- +// BerkleySocket + +void BerkleySocket::Close() +{ + if (TheSocket != INVALID_SOCKET) + { + closesocket(TheSocket); + TheSocket = INVALID_SOCKET; + } +} + +int32_t BerkleySocket::GetSockname(SockAddr *pSockAddrOut) +{ + struct sockaddr_in6 sa; + memset(&sa,0,sizeof(sa)); + int size = sizeof(sa); + int32_t i = getsockname(TheSocket, (sockaddr*) &sa, &size); + if (i>=0) + { + pSockAddrOut->Set(&sa); + } + return i; +} + + +//----------------------------------------------------------------------------- +// BitStream overloads for SockAddr + +BitStream& operator<<(BitStream& out, SockAddr& in) +{ + out.WriteBits((const unsigned char*) &in.Addr6, sizeof(in.Addr6)*8, true); + return out; +} + +BitStream& operator>>(BitStream& in, SockAddr& out) +{ + bool success = in.ReadBits((unsigned char*) &out.Addr6, sizeof(out.Addr6)*8, true); + OVR_ASSERT(success); + OVR_UNUSED(success); + return in; +} + + +//----------------------------------------------------------------------------- +// SockAddr + +SockAddr::SockAddr() +{ + WSAStartupSingleton::AddRef(); + + // Zero out the address to squelch static analysis tools + ZeroMemory(&Addr6, sizeof(Addr6)); +} + +SockAddr::SockAddr(SockAddr* address) +{ + WSAStartupSingleton::AddRef(); + Set(&address->Addr6); +} + +SockAddr::SockAddr(sockaddr_storage* storage) +{ + WSAStartupSingleton::AddRef(); + Set(storage); +} + +SockAddr::SockAddr(sockaddr_in6* address) +{ + WSAStartupSingleton::AddRef(); + Set(address); +} + +SockAddr::SockAddr(const char* hostAddress, uint16_t port, int sockType) +{ + WSAStartupSingleton::AddRef(); + Set(hostAddress, port, sockType); +} + +void SockAddr::Set(const sockaddr_storage* storage) +{ + memcpy(&Addr6, storage, sizeof(Addr6)); +} + +void SockAddr::Set(const sockaddr_in6* address) +{ + memcpy(&Addr6, address, sizeof(Addr6)); +} + +void SockAddr::Set(const char* hostAddress, uint16_t port, int sockType) +{ + memset(&Addr6, 0, sizeof(Addr6)); + + struct addrinfo hints; + + // make sure the struct is empty + memset(&hints, 0, sizeof (addrinfo)); + + hints.ai_socktype = sockType; // SOCK_DGRAM or SOCK_STREAM + hints.ai_flags = AI_PASSIVE; // fill in my IP for me + hints.ai_family = AF_UNSPEC ; + + // FIXME See OVR_Unix_Socket implementation and man pages for getaddrinfo. + // ai_protocol is expecting to be either IPPROTO_UDP and IPPROTO_TCP. + // But this has been working on windows so I'm leaving it be for + // now instead of introducing another variable. + hints.ai_protocol = IPPROTO_IPV6; + + struct addrinfo* servinfo = NULL; // will point to the results + + char portStr[32]; + OVR_itoa(port, portStr, sizeof(portStr), 10); + int errcode = getaddrinfo(hostAddress, portStr, &hints, &servinfo); + + if (0 != errcode) + { + OVR::LogError("{ERR-008w} getaddrinfo error: %s", gai_strerror(errcode)); + } + + OVR_ASSERT(servinfo); + + if (servinfo) + { + memcpy(&Addr6, servinfo->ai_addr, sizeof(Addr6)); + + freeaddrinfo(servinfo); + } +} + +uint16_t SockAddr::GetPort() +{ + return htons(Addr6.sin6_port); +} + +String SockAddr::ToString(bool writePort, char portDelineator) const +{ + char dest[INET6_ADDRSTRLEN + 1]; + + int ret = getnameinfo((struct sockaddr*)&Addr6, + sizeof(struct sockaddr_in6), + dest, + INET6_ADDRSTRLEN, + NULL, + 0, + NI_NUMERICHOST); + if (ret != 0) + { + dest[0] = '\0'; + } + + if (writePort) + { + unsigned char ch[2]; + ch[0]=portDelineator; + ch[1]=0; + OVR_strcat(dest, 16, (const char*) ch); + OVR_itoa(ntohs(Addr6.sin6_port), dest+strlen(dest), 16, 10); + } + + return String(dest); +} +bool SockAddr::IsLocalhost() const +{ + static const unsigned char localhost_bytes[] = + { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }; + + return memcmp(Addr6.sin6_addr.s6_addr, localhost_bytes, 16) == 0; +} +bool SockAddr::operator==( const SockAddr& right ) const +{ + return memcmp(&Addr6, &right.Addr6, sizeof(Addr6)) == 0; +} + +bool SockAddr::operator!=( const SockAddr& right ) const +{ + return !(*this == right); +} + +bool SockAddr::operator>( const SockAddr& right ) const +{ + return memcmp(&Addr6, &right.Addr6, sizeof(Addr6)) > 0; +} + +bool SockAddr::operator<( const SockAddr& right ) const +{ + return memcmp(&Addr6, &right.Addr6, sizeof(Addr6)) < 0; +} + +static bool SetSocketOptions(SocketHandle sock) +{ + int result = 0; + int sock_opt; + + // This doubles the max throughput rate + sock_opt = 1024 * 256; + result |= setsockopt(sock, SOL_SOCKET, SO_RCVBUF, (char *)& sock_opt, sizeof (sock_opt)); + + // Immediate hard close. Don't linger the socket, or recreating the socket quickly on Vista fails. + sock_opt = 0; + result |= setsockopt(sock, SOL_SOCKET, SO_LINGER, (char *)& sock_opt, sizeof (sock_opt)); + + // This doesn't make much difference: 10% maybe + sock_opt = 1024 * 16; + result |= setsockopt(sock, SOL_SOCKET, SO_SNDBUF, (char *)& sock_opt, sizeof (sock_opt)); + + // If all the setsockopt() returned 0 there were no failures, so return true for success, else false + return result == 0; +} + +void _Ioctlsocket(SocketHandle sock, unsigned long nonblocking) +{ + ioctlsocket(sock, FIONBIO, &nonblocking); +} + +static SocketHandle BindShared(int ai_family, int ai_socktype, BerkleyBindParameters* pBindParameters) +{ + SocketHandle sock; + + struct addrinfo hints; + memset(&hints, 0, sizeof (addrinfo)); // make sure the struct is empty + hints.ai_family = ai_family; + hints.ai_socktype = ai_socktype; + hints.ai_flags = AI_PASSIVE; // fill in my IP for me + struct addrinfo *servinfo=0, *aip; // will point to the results + char portStr[32]; + OVR_itoa(pBindParameters->Port, portStr, sizeof(portStr), 10); + + int errcode = 0; + if (!pBindParameters->Address.IsEmpty()) + errcode = getaddrinfo(pBindParameters->Address.ToCStr(), portStr, &hints, &servinfo); + else + errcode = getaddrinfo(0, portStr, &hints, &servinfo); + + if (0 != errcode) + { + OVR::LogError("{ERR-020w} getaddrinfo error: %s", gai_strerror(errcode)); + } + + for (aip = servinfo; aip != NULL; aip = aip->ai_next) + { + // Open socket. The address type depends on what + // getaddrinfo() gave us. + sock = socket(aip->ai_family, aip->ai_socktype, aip->ai_protocol); + if (sock != INVALID_SOCKET) + { + if (bind(sock, aip->ai_addr, (int)aip->ai_addrlen) != SOCKET_ERROR) + { + // The actual socket is always non-blocking + // I control blocking or not using WSAEventSelect + _Ioctlsocket(sock, 1); + freeaddrinfo(servinfo); + return sock; + } + + closesocket(sock); + } + } + + if (servinfo) { freeaddrinfo(servinfo); } + return INVALID_SOCKET; +} + + +//----------------------------------------------------------------------------- +// UDPSocket + +UDPSocket::UDPSocket() +{ + WSAStartupSingleton::AddRef(); + RecvBuf = new uint8_t[RecvBufSize]; +} + +UDPSocket::~UDPSocket() +{ + WSAStartupSingleton::Deref(); + delete[] RecvBuf; +} + +SocketHandle UDPSocket::Bind(BerkleyBindParameters *pBindParameters) +{ + SocketHandle s = BindShared(AF_INET6, SOCK_DGRAM, pBindParameters); + if (s == INVALID_SOCKET) + return s; + + Close(); + TheSocket = s; + SetSocketOptions(TheSocket); + + return TheSocket; +} + +void UDPSocket::OnRecv(SocketEvent_UDP* eventHandler, uint8_t* pData, int bytesRead, SockAddr* address) +{ + eventHandler->UDP_OnRecv(this, pData, bytesRead, address); +} + +int UDPSocket::Send(const void* pData, int bytes, SockAddr* address) +{ + return sendto(TheSocket, (const char*)pData, bytes, 0, (const sockaddr*)&address->Addr6, sizeof(address->Addr6)); +} + +void UDPSocket::Poll(SocketEvent_UDP *eventHandler) +{ + struct sockaddr_storage win32_addr; + socklen_t fromlen; + int bytesRead; + + // FIXME: Implement blocking poll wait for UDP + + // While some bytes are read, + while (fromlen = sizeof(win32_addr), // Must set fromlen each time + bytesRead = recvfrom(TheSocket, (char*)RecvBuf, RecvBufSize, 0, (sockaddr*)&win32_addr, &fromlen), + bytesRead > 0) + { + SockAddr address(&win32_addr); // Wrap address + + OnRecv(eventHandler, RecvBuf, bytesRead, &address); + } +} + + +//----------------------------------------------------------------------------- +// TCPSocket + +TCPSocket::TCPSocket() +{ + IsConnecting = false; + IsListenSocket = false; + WSAStartupSingleton::AddRef(); +} + +TCPSocket::TCPSocket(SocketHandle boundHandle, bool isListenSocket) +{ + TheSocket = boundHandle; + IsListenSocket = isListenSocket; + IsConnecting = false; + WSAStartupSingleton::AddRef(); + + if (TheSocket != INVALID_SOCKET) + { + SetSocketOptions(TheSocket); + + // The actual socket is always non-blocking + _Ioctlsocket(TheSocket, 1); + } +} + +TCPSocket::~TCPSocket() +{ + WSAStartupSingleton::Deref(); +} + +void TCPSocket::OnRecv(SocketEvent_TCP* eventHandler, uint8_t* pData, int bytesRead) +{ + eventHandler->TCP_OnRecv(this, pData, bytesRead); +} + +SocketHandle TCPSocket::Bind(BerkleyBindParameters* pBindParameters) +{ + SocketHandle s = BindShared(AF_INET6, SOCK_STREAM, pBindParameters); + if (s == INVALID_SOCKET) + return s; + + Close(); + + SetBlockingTimeout(pBindParameters->blockingTimeout); + TheSocket = s; + + SetSocketOptions(TheSocket); + + return TheSocket; +} + +int TCPSocket::Listen() +{ + if (IsListenSocket) + { + return 0; + } + + int i = listen(TheSocket, SOMAXCONN); + if (i >= 0) + { + IsListenSocket = true; + } + + return i; +} + +int TCPSocket::Connect(SockAddr* address) +{ + int retval; + + retval = connect(TheSocket, (struct sockaddr *) &address->Addr6, sizeof(address->Addr6)); + if (retval < 0) + { + DWORD dwIOError = WSAGetLastError(); + if (dwIOError == WSAEWOULDBLOCK) + { + IsConnecting = true; + return 0; + } + + LogError("[TCPSocket] ERROR: Connect failed. Error code - %d", (int)dwIOError); + } + + return retval; +} + +int TCPSocket::Send(const void* pData, int bytes) +{ + if (bytes <= 0) + { + return 0; + } + else + { + return send(TheSocket, (const char*)pData, bytes, 0); + } +} + + +//// TCPSocketPollState + +TCPSocketPollState::TCPSocketPollState() +{ + memset(&readFD, 0, sizeof(readFD)); + memset(&exceptionFD, 0, sizeof(exceptionFD)); + memset(&writeFD, 0, sizeof(writeFD)); + + FD_ZERO(&readFD); + FD_ZERO(&exceptionFD); + FD_ZERO(&writeFD); + largestDescriptor = INVALID_SOCKET; +} + +bool TCPSocketPollState::IsValid() const +{ + return largestDescriptor != INVALID_SOCKET; +} + +void TCPSocketPollState::Add(TCPSocket* tcpSocket) +{ + if (!tcpSocket) + { + return; + } + + SocketHandle handle = tcpSocket->GetSocketHandle(); + + if (largestDescriptor == INVALID_SOCKET || + largestDescriptor < handle) + { + largestDescriptor = handle; + } + + FD_SET(handle, &readFD); + FD_SET(handle, &exceptionFD); + + if (tcpSocket->IsConnecting) + { + FD_SET(handle, &writeFD); + } +} + +bool TCPSocketPollState::Poll(long usec, long seconds) +{ + timeval tv; + tv.tv_sec = seconds; + tv.tv_usec = usec; + + return (int)select((int)largestDescriptor + 1, &readFD, &writeFD, &exceptionFD, &tv) > 0; +} + +void TCPSocketPollState::HandleEvent(TCPSocket* tcpSocket, SocketEvent_TCP* eventHandler) +{ + if (!tcpSocket || !eventHandler) + { + return; + } + + SocketHandle handle = tcpSocket->GetSocketHandle(); + + if (tcpSocket->IsConnecting && FD_ISSET(handle, &writeFD)) + { + tcpSocket->IsConnecting = false; + eventHandler->TCP_OnConnected(tcpSocket); + } + + if (FD_ISSET(handle, &readFD)) + { + if (!tcpSocket->IsListenSocket) + { + static const int BUFF_SIZE = 8096; + char data[BUFF_SIZE]; + + int bytesRead = recv(handle, data, BUFF_SIZE, 0); + if (bytesRead > 0) + { + tcpSocket->OnRecv(eventHandler, (uint8_t*)data, bytesRead); + } + else // Disconnection event: + { + tcpSocket->IsConnecting = false; + eventHandler->TCP_OnClosed(tcpSocket); + } + } + else + { + struct sockaddr_storage sockAddr; + socklen_t sockAddrSize = sizeof(sockAddr); + + SocketHandle newSock = accept(handle, (sockaddr*)&sockAddr, (socklen_t*)&sockAddrSize); + if (newSock != INVALID_SOCKET) + { + SockAddr sa(&sockAddr); + eventHandler->TCP_OnAccept(tcpSocket, &sa, newSock); + } + } + } + + if (FD_ISSET(handle, &exceptionFD)) + { + tcpSocket->IsConnecting = false; + eventHandler->TCP_OnClosed(tcpSocket); + } +} + + +}} // namespace OVR::Net diff --git a/LibOVR/Src/Net/OVR_Win32_Socket.h b/LibOVR/Src/Net/OVR_Win32_Socket.h index ac66869..ed0a624 100644 --- a/LibOVR/Src/Net/OVR_Win32_Socket.h +++ b/LibOVR/Src/Net/OVR_Win32_Socket.h @@ -1,151 +1,150 @@ -/************************************************************************************
-
-PublicHeader: n/a
-Filename : OVR_Win32_Socket.h
-Content : Windows-specific socket-based networking implementation
-Created : June 10, 2014
-Authors : Kevin Jenkins
-
-Copyright : Copyright 2014 Oculus VR, LLC All Rights reserved.
-
-Licensed under the Oculus VR Rift SDK License Version 3.2 (the "License");
-you may not use the Oculus VR Rift SDK except in compliance with the License,
-which is provided at the time of installation or download, or which
-otherwise accompanies this software in either electronic or hard copy form.
-
-You may obtain a copy of the License at
-
-http://www.oculusvr.com/licenses/LICENSE-3.2
-
-Unless required by applicable law or agreed to in writing, the Oculus VR SDK
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-
-************************************************************************************/
-
-#ifndef OVR_Win32_Socket_h
-#define OVR_Win32_Socket_h
-
-#include "OVR_Socket.h"
-#include "OVR_BitStream.h"
-
-#include <WinSock2.h>
-#include <WS2tcpip.h>
-#define WIN32_LEAN_AND_MEAN
-#include <Windows.h>
-#include <io.h>
-
-namespace OVR { namespace Net {
-
-
-//-----------------------------------------------------------------------------
-// SockAddr
-
-// Abstraction for IPV6 socket address, with various convenience functions
-class SockAddr
-{
-public:
- SockAddr();
- SockAddr(SockAddr* sa);
- SockAddr(sockaddr_storage* sa);
- SockAddr(sockaddr_in6* sa);
- SockAddr(const char* hostAddress, uint16_t port, int sockType);
-
-public:
- void Set(const sockaddr_storage* sa);
- void Set(const sockaddr_in6* sa);
- void Set(const char* hostAddress, uint16_t port, int sockType); // SOCK_DGRAM or SOCK_STREAM
-
- uint16_t GetPort();
-
- String ToString(bool writePort, char portDelineator) const;
- bool IsLocalhost() const;
-
- void Serialize(BitStream* bs);
- bool Deserialize(BitStream);
-
- bool operator==( const SockAddr& right ) const;
- bool operator!=( const SockAddr& right ) const;
- bool operator >( const SockAddr& right ) const;
- bool operator <( const SockAddr& right ) const;
-
-public:
- sockaddr_in6 Addr6;
-};
-
-
-//-----------------------------------------------------------------------------
-// UDP Socket
-
-// Windows version of TCP socket
-class UDPSocket : public UDPSocketBase
-{
-public:
- UDPSocket();
- virtual ~UDPSocket();
-
-public:
- virtual SocketHandle Bind(BerkleyBindParameters* pBindParameters);
- virtual int Send(const void* pData, int bytes, SockAddr* address);
- virtual void Poll(SocketEvent_UDP* eventHandler);
-
-protected:
- static const int RecvBufSize = 1048576;
- uint8_t* RecvBuf;
-
- virtual void OnRecv(SocketEvent_UDP* eventHandler, uint8_t* pData,
- int bytesRead, SockAddr* address);
-};
-
-
-//-----------------------------------------------------------------------------
-// TCP Socket
-
-// Windows version of TCP socket
-class TCPSocket : public TCPSocketBase
-{
- friend class TCPSocketPollState;
-
-public:
- TCPSocket();
- TCPSocket(SocketHandle boundHandle, bool isListenSocket);
- virtual ~TCPSocket();
-
-public:
- virtual SocketHandle Bind(BerkleyBindParameters* pBindParameters);
- virtual int Listen();
- virtual int Connect(SockAddr* address);
- virtual int Send(const void* pData, int bytes);
-
-protected:
- virtual void OnRecv(SocketEvent_TCP* eventHandler, uint8_t* pData,
- int bytesRead);
-
-public:
- bool IsConnecting; // Is in the process of connecting?
-};
-
-
-//-----------------------------------------------------------------------------
-// TCPSocketPollState
-
-// Polls multiple blocking TCP sockets at once
-class TCPSocketPollState
-{
- fd_set readFD, exceptionFD, writeFD;
- SocketHandle largestDescriptor;
-
-public:
- TCPSocketPollState();
- bool IsValid() const;
- void Add(TCPSocket* tcpSocket);
- bool Poll(long usec = 30000, long seconds = 0);
- void HandleEvent(TCPSocket* tcpSocket, SocketEvent_TCP* eventHandler);
-};
-
-
-}} // OVR::Net
-
-#endif
+/************************************************************************************ + +PublicHeader: n/a +Filename : OVR_Win32_Socket.h +Content : Windows-specific socket-based networking implementation +Created : June 10, 2014 +Authors : Kevin Jenkins + +Copyright : Copyright 2014 Oculus VR, LLC All Rights reserved. + +Licensed under the Oculus VR Rift SDK License Version 3.2 (the "License"); +you may not use the Oculus VR Rift SDK except in compliance with the License, +which is provided at the time of installation or download, or which +otherwise accompanies this software in either electronic or hard copy form. + +You may obtain a copy of the License at + +http://www.oculusvr.com/licenses/LICENSE-3.2 + +Unless required by applicable law or agreed to in writing, the Oculus VR SDK +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +************************************************************************************/ + +#ifndef OVR_Win32_Socket_h +#define OVR_Win32_Socket_h + +#include "OVR_Socket.h" +#include "OVR_BitStream.h" + +#include <WinSock2.h> +#include <WS2tcpip.h> +#include "Kernel/OVR_Win32_IncludeWindows.h" +#include <io.h> + +namespace OVR { namespace Net { + + +//----------------------------------------------------------------------------- +// SockAddr + +// Abstraction for IPV6 socket address, with various convenience functions +class SockAddr +{ +public: + SockAddr(); + SockAddr(SockAddr* sa); + SockAddr(sockaddr_storage* sa); + SockAddr(sockaddr_in6* sa); + SockAddr(const char* hostAddress, uint16_t port, int sockType); + +public: + void Set(const sockaddr_storage* sa); + void Set(const sockaddr_in6* sa); + void Set(const char* hostAddress, uint16_t port, int sockType); // SOCK_DGRAM or SOCK_STREAM + + uint16_t GetPort(); + + String ToString(bool writePort, char portDelineator) const; + bool IsLocalhost() const; + + void Serialize(BitStream* bs); + bool Deserialize(BitStream); + + bool operator==( const SockAddr& right ) const; + bool operator!=( const SockAddr& right ) const; + bool operator >( const SockAddr& right ) const; + bool operator <( const SockAddr& right ) const; + +public: + sockaddr_in6 Addr6; +}; + + +//----------------------------------------------------------------------------- +// UDP Socket + +// Windows version of TCP socket +class UDPSocket : public UDPSocketBase +{ +public: + UDPSocket(); + virtual ~UDPSocket(); + +public: + virtual SocketHandle Bind(BerkleyBindParameters* pBindParameters); + virtual int Send(const void* pData, int bytes, SockAddr* address); + virtual void Poll(SocketEvent_UDP* eventHandler); + +protected: + static const int RecvBufSize = 1048576; + uint8_t* RecvBuf; + + virtual void OnRecv(SocketEvent_UDP* eventHandler, uint8_t* pData, + int bytesRead, SockAddr* address); +}; + + +//----------------------------------------------------------------------------- +// TCP Socket + +// Windows version of TCP socket +class TCPSocket : public TCPSocketBase +{ + friend class TCPSocketPollState; + +public: + TCPSocket(); + TCPSocket(SocketHandle boundHandle, bool isListenSocket); + virtual ~TCPSocket(); + +public: + virtual SocketHandle Bind(BerkleyBindParameters* pBindParameters); + virtual int Listen(); + virtual int Connect(SockAddr* address); + virtual int Send(const void* pData, int bytes); + +protected: + virtual void OnRecv(SocketEvent_TCP* eventHandler, uint8_t* pData, + int bytesRead); + +public: + bool IsConnecting; // Is in the process of connecting? +}; + + +//----------------------------------------------------------------------------- +// TCPSocketPollState + +// Polls multiple blocking TCP sockets at once +class TCPSocketPollState +{ + fd_set readFD, exceptionFD, writeFD; + SocketHandle largestDescriptor; + +public: + TCPSocketPollState(); + bool IsValid() const; + void Add(TCPSocket* tcpSocket); + bool Poll(long usec = 30000, long seconds = 0); + void HandleEvent(TCPSocket* tcpSocket, SocketEvent_TCP* eventHandler); +}; + + +}} // OVR::Net + +#endif |