root/oscpack/tags/release_1_0_2/ip/win32/UdpSocket.cpp

Revision 36, 14.6 kB (checked in by ross, 3 years ago)

Fixed bug where RunUntilSigInt() could only be called once; a second call caused an assertion failure.

  • Property svn:eol-style set to native
Line 
1 /*
2         oscpack -- Open Sound Control packet manipulation library
3         http://www.audiomulch.com/~rossb/oscpack
4
5         Copyright (c) 2004-2005 Ross Bencina <rossb@audiomulch.com>
6
7         Permission is hereby granted, free of charge, to any person obtaining
8         a copy of this software and associated documentation files
9         (the "Software"), to deal in the Software without restriction,
10         including without limitation the rights to use, copy, modify, merge,
11         publish, distribute, sublicense, and/or sell copies of the Software,
12         and to permit persons to whom the Software is furnished to do so,
13         subject to the following conditions:
14
15         The above copyright notice and this permission notice shall be
16         included in all copies or substantial portions of the Software.
17
18         Any person wishing to distribute modifications to the Software is
19         requested to send the modifications to the original developer so that
20         they can be incorporated into the canonical version.
21
22         THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
23         EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
24         MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
25         IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR
26         ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
27         CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
28         WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
29 */
30 #include "ip/UdpSocket.h"
31
32 #include <winsock2.h>   // this must come first to prevent errors with MSVC7
33 #include <windows.h>
34 #include <mmsystem.h>   // for timeGetTime()
35
36 #include <vector>
37 #include <algorithm>
38 #include <stdexcept>
39 #include <assert.h>
40 #include <signal.h>
41
42 #include "ip/NetworkingUtils.h"
43 #include "ip/PacketListener.h"
44 #include "ip/TimerListener.h"
45
46
// Winsock's address-length out-parameters (getsockname(), recvfrom()) are plain
// ints; this alias lets the code below use the POSIX spelling socklen_t.
typedef signed int socklen_t;
48
49
50 static void SockaddrFromIpEndpointName( struct sockaddr_in& sockAddr, const IpEndpointName& endpoint )
51 {
52     memset( (char *)&sockAddr, 0, sizeof(sockAddr ) );
53     sockAddr.sin_family = AF_INET;
54
55         sockAddr.sin_addr.s_addr =
56                 (endpoint.address == IpEndpointName::ANY_ADDRESS)
57                 ? INADDR_ANY
58                 : htonl( endpoint.address );
59
60         sockAddr.sin_port =
61                 (endpoint.port == IpEndpointName::ANY_PORT)
62                 ? (short)0
63                 : htons( (short)endpoint.port );
64 }
65
66
67 static IpEndpointName IpEndpointNameFromSockaddr( const struct sockaddr_in& sockAddr )
68 {
69         return IpEndpointName(
70                 (sockAddr.sin_addr.s_addr == INADDR_ANY)
71                         ? IpEndpointName::ANY_ADDRESS
72                         : ntohl( sockAddr.sin_addr.s_addr ),
73                 (sockAddr.sin_port == 0)
74                         ? IpEndpointName::ANY_PORT
75                         : ntohs( sockAddr.sin_port )
76                 );
77 }
78
79
80 class UdpSocket::Implementation{
81     NetworkInitializer networkInitializer_;
82
83         bool isBound_;
84         bool isConnected_;
85
86         SOCKET socket_;
87         struct sockaddr_in connectedAddr_;
88         struct sockaddr_in sendToAddr_;
89
90 public:
91
92         Implementation()
93                 : isBound_( false )
94                 , isConnected_( false )
95                 , socket_( INVALID_SOCKET )
96         {
97                 if( (socket_ = socket( AF_INET, SOCK_DGRAM, 0 )) == INVALID_SOCKET ){
98             throw std::runtime_error("unable to create udp socket\n");
99         }
100
101                 memset( &sendToAddr_, 0, sizeof(sendToAddr_) );
102         sendToAddr_.sin_family = AF_INET;
103         }
104
105         ~Implementation()
106         {
107                 if (socket_ != INVALID_SOCKET) closesocket(socket_);
108         }
109
110         IpEndpointName LocalEndpointFor( const IpEndpointName& remoteEndpoint ) const
111         {
112                 assert( isBound_ );
113
114                 // first connect the socket to the remote server
115         
116         struct sockaddr_in connectSockAddr;
117                 SockaddrFromIpEndpointName( connectSockAddr, remoteEndpoint );
118        
119         if (connect(socket_, (struct sockaddr *)&connectSockAddr, sizeof(connectSockAddr)) < 0) {
120             throw std::runtime_error("unable to connect udp socket\n");
121         }
122
123         // get the address
124
125         struct sockaddr_in sockAddr;
126         memset( (char *)&sockAddr, 0, sizeof(sockAddr ) );
127         socklen_t length = sizeof(sockAddr);
128         if (getsockname(socket_, (struct sockaddr *)&sockAddr, &length) < 0) {
129             throw std::runtime_error("unable to getsockname\n");
130         }
131        
132                 if( isConnected_ ){
133                         // reconnect to the connected address
134                         
135                         if (connect(socket_, (struct sockaddr *)&connectedAddr_, sizeof(connectedAddr_)) < 0) {
136                                 throw std::runtime_error("unable to connect udp socket\n");
137                         }
138
139                 }else{
140                         // unconnect from the remote address
141                 
142                         struct sockaddr_in unconnectSockAddr;
143                         SockaddrFromIpEndpointName( unconnectSockAddr, IpEndpointName() );
144
145                         if( connect(socket_, (struct sockaddr *)&unconnectSockAddr, sizeof(unconnectSockAddr)) < 0
146                                         && WSAGetLastError() != WSAEADDRNOTAVAIL ){
147                                 throw std::runtime_error("unable to un-connect udp socket\n");
148                         }
149                 }
150
151                 return IpEndpointNameFromSockaddr( sockAddr );
152         }
153
154         void Connect( const IpEndpointName& remoteEndpoint )
155         {
156                 SockaddrFromIpEndpointName( connectedAddr_, remoteEndpoint );
157        
158         if (connect(socket_, (struct sockaddr *)&connectedAddr_, sizeof(connectedAddr_)) < 0) {
159             throw std::runtime_error("unable to connect udp socket\n");
160         }
161
162                 isConnected_ = true;
163         }
164
165         void Send( const char *data, int size )
166         {
167                 assert( isConnected_ );
168
169         send( socket_, data, size, 0 );
170         }
171
172     void SendTo( const IpEndpointName& remoteEndpoint, const char *data, int size )
173         {
174                 sendToAddr_.sin_addr.s_addr = htonl( remoteEndpoint.address );
175         sendToAddr_.sin_port = htons( (short)remoteEndpoint.port );
176
177         sendto( socket_, data, size, 0, (sockaddr*)&sendToAddr_, sizeof(sendToAddr_) );
178         }
179
180         void Bind( const IpEndpointName& localEndpoint )
181         {
182                 struct sockaddr_in bindSockAddr;
183                 SockaddrFromIpEndpointName( bindSockAddr, localEndpoint );
184
185         if (bind(socket_, (struct sockaddr *)&bindSockAddr, sizeof(bindSockAddr)) < 0) {
186             throw std::runtime_error("unable to bind udp socket\n");
187         }
188
189                 isBound_ = true;
190         }
191
192         bool IsBound() const { return isBound_; }
193
194     int ReceiveFrom( IpEndpointName& remoteEndpoint, char *data, int size )
195         {
196                 assert( isBound_ );
197
198                 struct sockaddr_in fromAddr;
199         socklen_t fromAddrLen = sizeof(fromAddr);
200                  
201         int result = recvfrom(socket_, data, size, 0,
202                     (struct sockaddr *) &fromAddr, (socklen_t*)&fromAddrLen);
203                 if( result < 0 )
204                         return 0;
205
206                 remoteEndpoint.address = ntohl(fromAddr.sin_addr.s_addr);
207                 remoteEndpoint.port = ntohs(fromAddr.sin_port);
208
209                 return result;
210         }
211
212         SOCKET& Socket() { return socket_; }
213 };
214
215 UdpSocket::UdpSocket()
216 {
217         impl_ = new Implementation();
218 }
219
220 UdpSocket::~UdpSocket()
221 {
222         delete impl_;
223 }
224
225 IpEndpointName UdpSocket::LocalEndpointFor( const IpEndpointName& remoteEndpoint ) const
226 {
227         return impl_->LocalEndpointFor( remoteEndpoint );
228 }
229
230 void UdpSocket::Connect( const IpEndpointName& remoteEndpoint )
231 {
232         impl_->Connect( remoteEndpoint );
233 }
234
235 void UdpSocket::Send( const char *data, int size )
236 {
237         impl_->Send( data, size );
238 }
239
240 void UdpSocket::SendTo( const IpEndpointName& remoteEndpoint, const char *data, int size )
241 {
242         impl_->SendTo( remoteEndpoint, data, size );
243 }
244
245 void UdpSocket::Bind( const IpEndpointName& localEndpoint )
246 {
247         impl_->Bind( localEndpoint );
248 }
249
250 bool UdpSocket::IsBound() const
251 {
252         return impl_->IsBound();
253 }
254
255 int UdpSocket::ReceiveFrom( IpEndpointName& remoteEndpoint, char *data, int size )
256 {
257         return impl_->ReceiveFrom( remoteEndpoint, data, size );
258 }
259
260
261 struct AttachedTimerListener{
262         AttachedTimerListener( int id, int p, TimerListener *tl )
263                 : initialDelayMs( id )
264                 , periodMs( p )
265                 , listener( tl ) {}
266         int initialDelayMs;
267         int periodMs;
268         TimerListener *listener;
269 };
270
271
272 static bool CompareScheduledTimerCalls(
273                 const std::pair< double, AttachedTimerListener > & lhs, const std::pair< double, AttachedTimerListener > & rhs )
274 {
275         return lhs.first < rhs.first;
276 }
277
278
279 SocketReceiveMultiplexer *multiplexerInstanceToAbortWithSigInt_ = 0;
280
281 extern "C" /*static*/ void InterruptSignalHandler( int );
282 /*static*/ void InterruptSignalHandler( int )
283 {
284         multiplexerInstanceToAbortWithSigInt_->AsynchronousBreak();
285         signal( SIGINT, SIG_DFL );
286 }
287
288
289 class SocketReceiveMultiplexer::Implementation{
290     NetworkInitializer networkInitializer_;
291
292         std::vector< std::pair< PacketListener*, UdpSocket* > > socketListeners_;
293         std::vector< AttachedTimerListener > timerListeners_;
294
295         volatile bool break_;
296         HANDLE breakEvent_;
297
298         double GetCurrentTimeMs() const
299         {
300                 return timeGetTime(); // FIXME: bad choice if you want to run for more than 40 days
301         }
302
303 public:
304     Implementation()
305         {
306                 breakEvent_ = CreateEvent( NULL, FALSE, FALSE, NULL );
307         }
308
309     ~Implementation()
310         {
311                 CloseHandle( breakEvent_ );
312         }
313
314     void AttachSocketListener( UdpSocket *socket, PacketListener *listener )
315         {
316                 assert( std::find( socketListeners_.begin(), socketListeners_.end(), std::make_pair(listener, socket) ) == socketListeners_.end() );
317                 // we don't check that the same socket has been added multiple times, even though this is an error
318                 socketListeners_.push_back( std::make_pair( listener, socket ) );
319         }
320
321     void DetachSocketListener( UdpSocket *socket, PacketListener *listener )
322         {
323                 std::vector< std::pair< PacketListener*, UdpSocket* > >::iterator i =
324                                 std::find( socketListeners_.begin(), socketListeners_.end(), std::make_pair(listener, socket) );
325                 assert( i != socketListeners_.end() );
326
327                 socketListeners_.erase( i );
328         }
329
330     void AttachPeriodicTimerListener( int periodMilliseconds, TimerListener *listener )
331         {
332                 timerListeners_.push_back( AttachedTimerListener( periodMilliseconds, periodMilliseconds, listener ) );
333         }
334
335         void AttachPeriodicTimerListener( int initialDelayMilliseconds, int periodMilliseconds, TimerListener *listener )
336         {
337                 timerListeners_.push_back( AttachedTimerListener( initialDelayMilliseconds, periodMilliseconds, listener ) );
338         }
339
340     void DetachPeriodicTimerListener( TimerListener *listener )
341         {
342                 std::vector< AttachedTimerListener >::iterator i = timerListeners_.begin();
343                 while( i != timerListeners_.end() ){
344                         if( i->listener == listener )
345                                 break;
346                         ++i;
347                 }
348
349                 assert( i != timerListeners_.end() );
350
351                 timerListeners_.erase( i );
352         }
353
354     void Run()
355         {
356                 break_ = false;
357
358                 // prepare the window events which we use to wake up on incoming data
359                 // we use this instead of select() primarily to support the AsyncBreak()
360                 // mechanism.
361
362                 std::vector<HANDLE> events( socketListeners_.size() + 1, 0 );
363                 int j=0;
364                 for( std::vector< std::pair< PacketListener*, UdpSocket* > >::iterator i = socketListeners_.begin();
365                                 i != socketListeners_.end(); ++i, ++j ){
366
367                         HANDLE event = CreateEvent( NULL, FALSE, FALSE, NULL );
368                         WSAEventSelect( i->second->impl_->Socket(), event, FD_READ ); // note that this makes the socket non-blocking which is why we can safely call RecieveFrom() on all sockets below
369                         events[j] = event;
370                 }
371
372
373                 events[ socketListeners_.size() ] = breakEvent_; // last event in the collection is the break event
374
375                
376                 // configure the timer queue
377                 double currentTimeMs = GetCurrentTimeMs();
378
379                 // expiry time ms, listener
380                 std::vector< std::pair< double, AttachedTimerListener > > timerQueue_;
381                 for( std::vector< AttachedTimerListener >::iterator i = timerListeners_.begin();
382                                 i != timerListeners_.end(); ++i )
383                         timerQueue_.push_back( std::make_pair( currentTimeMs + i->initialDelayMs, *i ) );
384                 std::sort( timerQueue_.begin(), timerQueue_.end(), CompareScheduledTimerCalls );
385
386                 const int MAX_BUFFER_SIZE = 4098;
387                 char *data = new char[ MAX_BUFFER_SIZE ];
388                 IpEndpointName remoteEndpoint;
389
390                 while( !break_ ){
391
392                         double currentTimeMs = GetCurrentTimeMs();
393
394             DWORD waitTime = INFINITE;
395             if( !timerQueue_.empty() ){
396
397                 waitTime = (DWORD)( timerQueue_.front().first >= currentTimeMs
398                             ? timerQueue_.front().first - currentTimeMs
399                             : 0 );
400             }
401
402                         DWORD waitResult = WaitForMultipleObjects( (DWORD)socketListeners_.size() + 1, &events[0], FALSE, waitTime );
403                         if( break_ )
404                                 break;
405
406                         if( waitResult != WAIT_TIMEOUT ){
407                                 for( int i = waitResult - WAIT_OBJECT_0; i < (int)socketListeners_.size(); ++i ){
408                                         int size = socketListeners_[i].second->ReceiveFrom( remoteEndpoint, data, MAX_BUFFER_SIZE );
409                                         if( size > 0 ){
410                                                 socketListeners_[i].first->ProcessPacket( data, size, remoteEndpoint );
411                                                 if( break_ )
412                                                         break;
413                                         }
414                                 }
415                         }
416
417                         // execute any expired timers
418                         currentTimeMs = GetCurrentTimeMs();
419                         bool resort = false;
420                         for( std::vector< std::pair< double, AttachedTimerListener > >::iterator i = timerQueue_.begin();
421                                         i != timerQueue_.end() && i->first <= currentTimeMs; ++i ){
422
423                                 i->second.listener->TimerExpired();
424                                 if( break_ )
425                                         break;
426
427                                 i->first += i->second.periodMs;
428                                 resort = true;
429                         }
430                         if( resort )
431                                 std::sort( timerQueue_.begin(), timerQueue_.end(), CompareScheduledTimerCalls );
432                 }
433
434                 delete [] data;
435
436                 // free events
437                 j = 0;
438                 for( std::vector< std::pair< PacketListener*, UdpSocket* > >::iterator i = socketListeners_.begin();
439                                 i != socketListeners_.end(); ++i, ++j ){
440
441                         WSAEventSelect( i->second->impl_->Socket(), events[j], 0 ); // remove association between socket and event
442                         CloseHandle( events[j] );
443                         unsigned long enableNonblocking = 0;
444                         ioctlsocket( i->second->impl_->Socket(), FIONBIO, &enableNonblocking );  // make the socket blocking again
445                 }
446         }
447
448     void Break()
449         {
450                 break_ = true;
451         }
452
453     void AsynchronousBreak()
454         {
455                 break_ = true;
456                 SetEvent( breakEvent_ );
457         }
458 };
459
460
461
462 SocketReceiveMultiplexer::SocketReceiveMultiplexer()
463 {
464         impl_ = new Implementation();
465 }
466
467 SocketReceiveMultiplexer::~SocketReceiveMultiplexer()
468 {       
469         delete impl_;
470 }
471
472 void SocketReceiveMultiplexer::AttachSocketListener( UdpSocket *socket, PacketListener *listener )
473 {
474         impl_->AttachSocketListener( socket, listener );
475 }
476
477 void SocketReceiveMultiplexer::DetachSocketListener( UdpSocket *socket, PacketListener *listener )
478 {
479         impl_->DetachSocketListener( socket, listener );
480 }
481
482 void SocketReceiveMultiplexer::AttachPeriodicTimerListener( int periodMilliseconds, TimerListener *listener )
483 {
484         impl_->AttachPeriodicTimerListener( periodMilliseconds, listener );
485 }
486
487 void SocketReceiveMultiplexer::AttachPeriodicTimerListener( int initialDelayMilliseconds, int periodMilliseconds, TimerListener *listener )
488 {
489         impl_->AttachPeriodicTimerListener( initialDelayMilliseconds, periodMilliseconds, listener );
490 }
491
492 void SocketReceiveMultiplexer::DetachPeriodicTimerListener( TimerListener *listener )
493 {
494         impl_->DetachPeriodicTimerListener( listener );
495 }
496
497 void SocketReceiveMultiplexer::Run()
498 {
499         impl_->Run();
500 }
501
502 void SocketReceiveMultiplexer::RunUntilSigInt()
503 {
504         assert( multiplexerInstanceToAbortWithSigInt_ == 0 ); /* at present we support only one multiplexer instance running until sig int */
505         multiplexerInstanceToAbortWithSigInt_ = this;
506         signal( SIGINT, InterruptSignalHandler );
507         impl_->Run();
508         signal( SIGINT, SIG_DFL );
509         multiplexerInstanceToAbortWithSigInt_ = 0;
510 }
511
512 void SocketReceiveMultiplexer::Break()
513 {
514         impl_->Break();
515 }
516
517 void SocketReceiveMultiplexer::AsynchronousBreak()
518 {
519         impl_->AsynchronousBreak();
520 }
521
Note: See TracBrowser for help on using the browser.