Jo Shields a575963da9 Imported Upstream version 3.6.0
Former-commit-id: da6be194a6b1221998fc28233f2503bd61dd9d14
2014-08-13 10:39:27 +01:00

499 lines
14 KiB
C#

//
// Mono.Net.Dns.SimpleResolver
//
// Authors:
// Gonzalo Paniagua Javier (gonzalo.mono@gmail.com)
//
// Copyright 2011 Gonzalo Paniagua Javier
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.Net;
using System.Net.Sockets;
using System.Net.NetworkInformation;
using System.Text;
using System.Threading;
namespace Mono.Net.Dns {
sealed class SimpleResolver : IDisposable {
// Shared empty arrays handed out when a lookup has no aliases / no addresses.
// ProcessResponse compares against EmptyAddresses by reference, so these must stay single instances.
static readonly string [] EmptyStrings = new string [0];
static readonly IPAddress [] EmptyAddresses = new IPAddress [0];
// DNS servers collected by InitFromSystem (IPv4 addresses, port 53).
IPEndPoint [] endpoints;
// UDP socket created by InitSocket; set back to null by Dispose.
Socket client;
// Queries waiting for a response, keyed by DNS header ID.
// The dictionary is also the lock object that guards it, so it must never be reassigned.
readonly Dictionary<int, SimpleResolverEventArgs> queries;
// Cached delegates for the socket receive completion and the per-query timer.
readonly AsyncCallback receive_cb;
readonly TimerCallback timeout_cb;
// Set by Dispose; checked by OnReceive to ignore completions that arrive afterwards.
bool disposed;
#if REUSE_RESPONSES
Stack<DnsResponse> responses_avail = new Stack<DnsResponse> ();
#endif
// Creates a resolver: sets up the pending-query table and the cached callbacks,
// reads the DNS servers from the system and opens the UDP socket.
public SimpleResolver ()
{
	timeout_cb = OnTimeout;
	receive_cb = OnReceive;
	queries = new Dictionary<int, SimpleResolverEventArgs> ();

	// The socket connects to the first server found, so the servers must be loaded first.
	InitFromSystem ();
	InitSocket ();
}
// Closes the socket. Safe to call more than once: only the first call does anything.
void IDisposable.Dispose ()
{
	if (disposed)
		return;

	// Flag first so that OnReceive ignores the completion triggered by closing the socket.
	disposed = true;
	if (client == null)
		return;

	client.Close ();
	client = null;
}
// Public alias for the explicitly implemented IDisposable.Dispose.
public void Close ()
{
	IDisposable self = this;
	self.Dispose ();
}
// Fills args with the answer used for an empty host name.
// FIXME: this always reports "localhost" / the IPv4 loopback address; no real
// enumeration of the machine's addresses is performed.
void GetLocalHost (SimpleResolverEventArgs args)
{
	IPHostEntry local = new IPHostEntry ();
	local.HostName = "localhost";
	local.AddressList = new IPAddress [] { IPAddress.Loopback };
	local.Aliases = EmptyStrings;

	args.ResolverError = 0;
	args.HostEntry = local;
}
// Resolves args.HostName to IP addresses using a type A query.
//   - Might fill in Aliases.
//   - An IP address literal is returned as-is, without querying.
//   - "" yields the local host entry (see GetLocalHost).
//
// Returns false when args.HostEntry was filled in synchronously and no query was
// sent; returns true when a query was sent, in which case args.OnCompleted is
// invoked later from the receive or timeout callback.
//
// Throws ArgumentNullException when args or args.HostName is null, and
// ArgumentException when args.HostName is longer than 255 characters.
public bool GetHostAddressesAsync (SimpleResolverEventArgs args)
{
	if (args == null)
		throw new ArgumentNullException ("args");
	// The single-string ArgumentNullException constructor takes a parameter name,
	// not a message, so pass both explicitly.
	if (args.HostName == null)
		throw new ArgumentNullException ("args", "args.HostName is null");
	if (args.HostName.Length > 255)
		throw new ArgumentException ("args.HostName is too long", "args");

	args.Reset (ResolverAsyncOperation.GetHostAddresses);
	string host = args.HostName;
	if (host.Length == 0) {
		GetLocalHost (args);
		return false;
	}

	IPAddress addr;
	if (IPAddress.TryParse (host, out addr)) {
		// Already an address: echo it back.
		IPHostEntry entry = new IPHostEntry ();
		entry.HostName = host;
		entry.Aliases = EmptyStrings;
		entry.AddressList = new IPAddress [1] { addr };
		args.HostEntry = entry;
		return false;
	}

	SendAQuery (args, true);
	return true;
}
// Resolves args.HostName to a host entry.
//   - For names -> type A query.
//   - For IP addresses -> PTR + A -> will at least return itself.
//   - "" yields the local host entry (see GetLocalHost).
// Careful: for IP addresses with PTR, the hostname might yield different IP addresses!
//
// Returns false when args.HostEntry was filled in synchronously and no query was
// sent; returns true when a query was sent, in which case args.OnCompleted is
// invoked later from the receive or timeout callback.
//
// Throws ArgumentNullException when args or args.HostName is null, and
// ArgumentException when args.HostName is longer than 255 characters.
public bool GetHostEntryAsync (SimpleResolverEventArgs args)
{
	if (args == null)
		throw new ArgumentNullException ("args");
	// The single-string ArgumentNullException constructor takes a parameter name,
	// not a message, so pass both explicitly.
	if (args.HostName == null)
		throw new ArgumentNullException ("args", "args.HostName is null");
	if (args.HostName.Length > 255)
		throw new ArgumentException ("args.HostName is too long", "args");

	args.Reset (ResolverAsyncOperation.GetHostEntry);
	string host = args.HostName;
	if (host.Length == 0) {
		GetLocalHost (args);
		return false;
	}

	IPAddress addr;
	if (IPAddress.TryParse (host, out addr)) {
		// Pre-fill the entry with the address itself so that a failed reverse
		// lookup still leaves a usable result, then ask for the PTR record.
		IPHostEntry entry = new IPHostEntry ();
		entry.HostName = host;
		entry.Aliases = EmptyStrings;
		entry.AddressList = new IPAddress [1] { addr };
		args.HostEntry = entry;
		args.PTRAddress = addr;
		SendPTRQuery (args, true);
		return true;
	}

	// 3. For IP addresses:
	//	3.1 Parsing IP succeeds
	//	3.2 Reverse lookup of the IP fills in HostName -> fails? HostName = IP
	//	3.3 The hostname resulting from this is used to query DNS again to get the IP addresses
	//
	// Exclude IPv6 addresses if not supported by the system
	// .Aliases is always empty
	// Length > 255
	SendAQuery (args, true);
	return true;
}
// Registers args as the pending query for the ID currently stored in query's header.
// Returns false, without changing anything, when that ID is already taken.
bool AddQuery (DnsQuery query, SimpleResolverEventArgs args)
{
	lock (queries) {
		bool id_is_free = !queries.ContainsKey (query.Header.ID);
		if (id_is_free)
			queries [query.Header.ID] = args;
		return id_is_free;
	}
}
// Builds a query for the given name, question type and question class.
static DnsQuery GetQuery (string host, DnsQType q, DnsQClass c)
{
	DnsQuery built = new DnsQuery (host, q, c);
	return built;
}
// Sends a type A / class IN query for args.HostName.
void SendAQuery (SimpleResolverEventArgs args, bool add_it)
{
	string host = args.HostName;
	SendAQuery (args, host, add_it);
}

// Sends a type A / class IN query for an explicit host name on behalf of args.
// add_it is passed through to SendQuery.
void SendAQuery (SimpleResolverEventArgs args, string host, bool add_it)
{
	SendQuery (args, GetQuery (host, DnsQType.A, DnsQClass.IN), add_it);
}
// Builds the reverse-lookup (PTR) domain name for an address.
//   IPv4: the four octets in reverse order, in decimal, under "in-addr.arpa"
//         (192.168.1.2 -> "2.1.168.192.in-addr.arpa").
//   IPv6: the 32 nibbles in reverse order, in lowercase hex, under "ip6.arpa".
static string GetPTRName (IPAddress address)
{
	byte [] bytes = address.GetAddressBytes ();

	if (address.AddressFamily == AddressFamily.InterNetworkV6) {
		const string hex_digits = "0123456789abcdef";
		// 32 nibbles + 32 dots + "ip6.arpa".Length
		StringBuilder v6 = new StringBuilder (72);
		for (int i = bytes.Length - 1; i >= 0; i--) {
			// Walking backwards, the low nibble of each byte comes first.
			v6.Append (hex_digits [bytes [i] & 0x0F]).Append ('.');
			v6.Append (hex_digits [bytes [i] >> 4]).Append ('.');
		}
		v6.Append ("ip6.arpa");
		return v6.ToString ();
	}

	// "XXX.XXX.XXX.XXX.in-addr.arpa".Length
	StringBuilder sb = new StringBuilder (28);
	for (int i = bytes.Length - 1; i >= 0; i--) {
		sb.AppendFormat ("{0}.", bytes [i]);
	}
	sb.Append ("in-addr.arpa");
	return sb.ToString ();
}
// Sends a PTR / class IN query for the reverse-lookup name of args.PTRAddress.
// add_it is passed through to SendQuery.
void SendPTRQuery (SimpleResolverEventArgs args, bool add_it)
{
	string reverse_name = GetPTRName (args.PTRAddress);
	DnsQuery ptr_query = GetQuery (reverse_name, DnsQType.PTR, DnsQClass.IN);
	SendQuery (args, ptr_query, add_it);
}
// Single generator for query IDs. Creating a new Random per attempt can hand back
// the same value repeatedly when the seeds are time based. Random instances are
// not thread-safe, so every use takes the lock.
static readonly Random id_generator = new Random ();

// Sends query over the socket on behalf of args and (re)arms args' 5 second timer.
//
// add_it == true : a fresh random ID is chosen and args is registered under it in
//                  the pending table; args.QueryID is updated to that ID.
// add_it == false: the query is sent with the ID already stored in args.QueryID
//                  (used for retries).
//
// Throws InvalidOperationException when no unused ID is found after 500 attempts.
void SendQuery (SimpleResolverEventArgs args, DnsQuery query, bool add_it)
{
	// TODO: not sure about reusing IDs when add_it == false
	if (add_it) {
		int attempts = 0;
		do {
			// Bound the search so a full table cannot spin forever.
			if (++attempts > 500)
				throw new InvalidOperationException ("Too many pending queries (or really bad luck)");
			lock (id_generator) {
				query.Header.ID = (ushort) id_generator.Next (1, 65534);
			}
		} while (AddQuery (query, args) == false);
		args.QueryID = query.Header.ID;
	} else {
		query.Header.ID = args.QueryID;
	}

	// One-shot timer: created on first use, re-armed afterwards.
	if (args.Timer == null)
		args.Timer = new Timer (timeout_cb, args, 5000, Timeout.Infinite);
	else
		args.Timer.Change (5000, Timeout.Infinite);

	// Fire and forget: no completion callback is registered for the send.
	client.BeginSend (query.Packet, 0, query.Length, SocketFlags.None, null, null);
}
// Returns the buffer that the next datagram will be received into.
// Without REUSE_RESPONSES this is a newly allocated 512-byte array
// (presumably the classic DNS-over-UDP message size limit; confirm).
// NOTE(review): the REUSE_RESPONSES branch returns a DnsResponse from a method
// declared to return byte [] -- it looks unfinished; confirm before defining the symbol.
byte [] GetFreshBuffer ()
{
#if !REUSE_RESPONSES
return new byte [512];
#else
// Take a pooled response when one is available, otherwise allocate a new one.
DnsResponse response = null;
lock (responses_avail) {
if (responses_avail.Count > 0) {
response = responses_avail.Pop ();
}
}
if (response == null) {
response = new DnsResponse ();
} else {
response.Reset ();
}
return response;
#endif
}
// Called by OnReceive once it is done with a receive buffer.
// Does nothing unless REUSE_RESPONSES is defined.
// NOTE(review): the REUSE_RESPONSES branch pushes `response`, which is neither a
// parameter nor a local of this method (the parameter is `buffer`) -- it looks
// unfinished; confirm before defining the symbol.
void FreeBuffer (byte [] buffer)
{
#if REUSE_RESPONSES
// TODO: set some limit here. Configurable?
lock (responses_avail) {
responses_avail.Push (response);
}
#endif
}
// Opens the IPv4 UDP socket, binds it to any local address and port, connects it
// to the first DNS server in endpoints and starts the first asynchronous receive.
void InitSocket ()
{
	Socket udp = new Socket (AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
	client = udp;
	udp.Blocking = true;

	IPEndPoint any_local = new IPEndPoint (IPAddress.Any, 0);
	udp.Bind (any_local);
	udp.Connect (endpoints [0]);

	BeginReceive ();
}
// Posts one asynchronous receive on the socket. The buffer travels as the async
// state so that OnReceive can get it back.
void BeginReceive ()
{
	byte [] incoming = GetFreshBuffer ();
	int capacity = incoming.Length;
	client.BeginReceive (incoming, 0, capacity, SocketFlags.None, receive_cb, incoming);
}
// Timer callback: fires when a query got no response within its timeout.
// obj is the SimpleResolverEventArgs the timer was created for.
//
// The first timeout re-sends the query with the same ID. The second one removes
// the query from the pending table and completes args with ResolverError.Timeout.
void OnTimeout (object obj)
{
	SimpleResolverEventArgs args = (SimpleResolverEventArgs) obj;
	bool timed_out = false;

	lock (queries) {
		SimpleResolverEventArgs pending;
		if (!queries.TryGetValue (args.QueryID, out pending))
			return; // Already processed.

		// The ID now belongs to a different query, so this timer is stale.
		// Throwing here would surface as an unhandled exception on a timer thread.
		if (args != pending)
			return;

		args.Retries++;
		if (args.Retries > 1) {
			// Give the ID back; otherwise the entry stays in the table forever.
			queries.Remove (args.QueryID);
			timed_out = true;
		} else if (args.PTRAddress != null) {
			// A reverse lookup is still pending: repeat the PTR query, not an A query.
			SendPTRQuery (args, false);
		} else {
			SendAQuery (args, false);
		}
	}

	// Run the completion outside the lock so that user code never executes while
	// the pending table is locked.
	if (timed_out) {
		args.ResolverError = ResolverError.Timeout;
		args.OnCompleted (this);
	}
}
// Completion callback for the receive posted by BeginReceive.
// Posts the next receive, matches the datagram to a pending query by its header
// ID, lets ProcessResponse fill in the result and then either completes the
// query or sends a follow-up A query.
void OnReceive (IAsyncResult ares)
{
// Completions that arrive after Dispose are ignored.
if (disposed)
return;
int nread = 0;
EndPoint remote_ep = client.RemoteEndPoint;
try {
nread = client.EndReceive (ares);
} catch (Exception e) {
// The failure is only written to stderr; nread stays 0, so the datagram is skipped below.
Console.Error.WriteLine (e);
}
// Keep listening: post the next receive before handling this datagram.
BeginReceive ();
// The buffer was passed as the async state by BeginReceive.
byte [] buffer = (byte []) ares.AsyncState;
// Presumably 12 is the size of the DNS header, i.e. shorter datagrams carry no payload -- confirm.
if (nread > 12) {
DnsResponse response = new DnsResponse (buffer, nread);
int id = response.Header.ID;
SimpleResolverEventArgs args = null;
// Claim the pending query; responses with an unknown ID leave args null and are dropped.
lock (queries) {
if (queries.TryGetValue (id, out args)) {
queries.Remove (id);
}
}
if (args != null) {
// Stop the timeout timer while the response is being handled.
Timer t = args.Timer;
if (t != null)
t.Change (Timeout.Infinite, Timeout.Infinite);
try {
ProcessResponse (args, response, remote_ep);
} catch (Exception e) {
// Any exception while processing is reported through args as error -1 plus the message.
args.ResolverError = (ResolverError) (-1);
args.ErrorMessage = e.Message;
}
IPHostEntry entry = args.HostEntry;
// Follow-up A query for the host name of a reverse lookup.
// NOTE(review): this branch is only taken when ResolverError is non-zero, while the
// notes in GetHostEntryAsync describe re-querying the name obtained from the PTR
// lookup -- confirm that `!= 0` is what is intended here.
if (args.ResolverError != 0 && args.PTRAddress != null && entry != null && entry.HostName != null) {
args.PTRAddress = null;
SendAQuery (args, entry.HostName, true);
// SendQuery has already re-armed the timer; this arms it once more with the same timeout.
args.Timer.Change (5000, Timeout.Infinite);
} else {
args.OnCompleted (this);
}
}
}
FreeBuffer (buffer);
}
// Records a response-header validation failure on args; what names the failed check.
static void FailHeaderCheck (SimpleResolverEventArgs args, string what)
{
	args.ResolverError = ResolverError.ResponseHeaderError;
	args.ErrorMessage = what;
}

// Validates a response and copies its answers into args.HostEntry.
// Failures are reported through args.ResolverError / args.ErrorMessage.
// While a reverse lookup is pending (args.PTRAddress != null), a non-zero
// response code or an empty answer section is not treated as an error.
void ProcessResponse (SimpleResolverEventArgs args, DnsResponse response, EndPoint server_ep)
{
	DnsRCode rcode = response.Header.RCode;
	if (rcode != 0) {
		// PTR query failed -> no error, we have the IP
		if (args.PTRAddress == null)
			args.ResolverError = (ResolverError) rcode;
		return;
	}

	// TODO: verify IP of the server is in our list and the same one that got the query
	IPEndPoint server = (IPEndPoint) server_ep;
	if (server.Port != 53) {
		FailHeaderCheck (args, "Port");
		return;
	}

	DnsHeader hdr = response.Header;
	if (!hdr.IsQuery) {
		FailHeaderCheck (args, "IsQuery");
		return;
	}

	// TODO: handle Truncation. Retry with bigger buffer?
	if (hdr.QuestionCount > 1) {
		FailHeaderCheck (args, "QuestionCount");
		return;
	}

	ReadOnlyCollection<DnsQuestion> questions = response.GetQuestions ();
	if (questions.Count != 1) {
		FailHeaderCheck (args, "QuestionCount 2");
		return;
	}

	// The question name is deliberately not compared with args.HostName:
	// the answer might have a dot at the end, etc.
	DnsQuestion question = questions [0];
	DnsQType qtype = question.Type;
	bool address_question = qtype == DnsQType.A || qtype == DnsQType.AAAA;
	if (!address_question && qtype != DnsQType.PTR) {
		FailHeaderCheck (args, "QType " + qtype);
		return;
	}

	if (question.Class != DnsQClass.IN) {
		FailHeaderCheck (args, "QClass " + question.Class);
		return;
	}

	ReadOnlyCollection<DnsResourceRecord> answers = response.GetAnswers ();
	if (answers.Count == 0) {
		// PTR query failed -> no error
		if (args.PTRAddress == null) {
			args.ResolverError = ResolverError.NameError; // is this ok?
			args.ErrorMessage = "NoAnswers";
		}
		return;
	}

	List<string> cnames = null;
	List<IPAddress> found = null;
	foreach (DnsResourceRecord rr in answers) {
		if (rr.Class != DnsClass.IN)
			continue;

		if (rr.Type == DnsType.PTR) {
			// The first PTR record ends processing: its name becomes the host name,
			// together with the aliases collected so far and no addresses.
			args.HostEntry.HostName = ((DnsResourceRecordPTR) rr).DName;
			args.HostEntry.Aliases = cnames != null ? cnames.ToArray () : EmptyStrings;
			args.HostEntry.AddressList = EmptyAddresses;
			return;
		}

		if (rr.Type == DnsType.A || rr.Type == DnsType.AAAA) {
			if (found == null)
				found = new List<IPAddress> ();
			found.Add (((DnsResourceRecordIPAddress) rr).Address);
		} else if (rr.Type == DnsType.CNAME) {
			if (cnames == null)
				cnames = new List<string> ();
			cnames.Add (((DnsResourceRecordCName) rr).CName);
		}
	}

	IPHostEntry result = args.HostEntry ?? new IPHostEntry ();
	// Without a host name yet, the first alias is promoted to host name.
	if (result.HostName == null && cnames != null && cnames.Count > 0) {
		result.HostName = cnames [0];
		cnames.RemoveAt (0);
	}
	result.Aliases = cnames != null ? cnames.ToArray () : EmptyStrings;
	result.AddressList = found != null ? found.ToArray () : EmptyAddresses;
	args.HostEntry = result;

	if (address_question) {
		// found is null exactly when AddressList was set to EmptyAddresses above.
		if (found == null) {
			args.ResolverError = ResolverError.NameError;
			args.ErrorMessage = "No addresses in response";
		}
	} else if (qtype == DnsQType.PTR && result.HostName == null) {
		args.ResolverError = ResolverError.NameError;
		args.ErrorMessage = "No PTR in response";
	}
}
// Collects the DNS servers configured on the non-loopback network interfaces
// into endpoints: IPv4 addresses only, port 53, duplicates dropped, in the
// order in which they are encountered.
void InitFromSystem ()
{
	List<IPEndPoint> servers = new List<IPEndPoint> ();
	NetworkInterface [] interfaces = NetworkInterface.GetAllNetworkInterfaces ();

	for (int i = 0; i < interfaces.Length; i++) {
		NetworkInterface nic = interfaces [i];
		if (nic.NetworkInterfaceType != NetworkInterfaceType.Loopback) {
			foreach (IPAddress dns in nic.GetIPProperties ().DnsAddresses) {
				if (dns.AddressFamily != AddressFamily.InterNetworkV6) {
					IPEndPoint candidate = new IPEndPoint (dns, 53);
					if (!servers.Contains (candidate))
						servers.Add (candidate);
				}
			}
		}
	}

	endpoints = servers.ToArray ();
}
}
}