// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information.

namespace System.Data.Entity
{
    using System.Collections;
    using System.Collections.Generic;
    using System.Collections.ObjectModel;
    using System.Linq;
    using System.Linq.Expressions;
    using System.Threading;
    using System.Threading.Tasks;

    /// <summary>
    /// In-memory implementation of <see cref="IDbSet{T}" /> based on a <see cref="HashSet{T}" />.
    /// </summary>
    /// <remarks>
    /// Entities are held in a <see cref="HashSet{T}" /> created with the default equality comparer, so
    /// membership follows <typeparamref name="T" />'s own equality (reference equality unless
    /// <see cref="object.Equals(object)" /> is overridden). Queries against the set run as LINQ-to-Objects
    /// over the set's contents at the time they are enumerated.
    /// </remarks>
    /// <typeparam name="T">Type of elements to be stored in the set</typeparam>
    public class HashSetBasedDbSet<T> : IDbSet<T>
        where T : class, new()
    {
        private readonly HashSet<T> _data;
        private readonly IQueryable<T> _query;
        private readonly Func<IEnumerable<T>, T> _findFunc;

        /// <summary>
        /// Creates an empty set. <see cref="Find(object[])" /> is not supported on a set built this way.
        /// </summary>
        public HashSetBasedDbSet()
            : this(null)
        {
        }

        /// <summary>
        /// Creates an empty set that uses <paramref name="findFunc" /> to implement <see cref="Find(object[])" />.
        /// </summary>
        /// <param name="findFunc">
        /// Function that picks the entity to return from the current contents of the set, or <c>null</c>
        /// to make <see cref="Find(object[])" /> throw <see cref="NotSupportedException" />.
        /// </param>
        public HashSetBasedDbSet(Func<IEnumerable<T>, T> findFunc)
        {
            _data = new HashSet<T>();
            // Wrapping the live HashSet (rather than a copy) keeps queries in sync with Add/Remove/Attach.
            _query = _data.AsQueryable();
            _findFunc = findFunc;
        }

        /// <summary>
        /// Returns the entity selected by the find function supplied at construction.
        /// </summary>
        /// <param name="keyValues">
        /// Not used: the find function only receives the set's contents, so it must decide on its own
        /// which entity to return.
        /// </param>
        /// <returns>Whatever the find function returns.</returns>
        /// <exception cref="NotSupportedException">No find function was supplied at construction.</exception>
        public T Find(params object[] keyValues)
        {
            if (_findFunc == null)
            {
                throw new NotSupportedException("If you want to call find then use the constructor that specifies a find func.");
            }

            return _findFunc(_data);
        }

        /// <summary>
        /// Asynchronous counterpart of <see cref="Find(object[])" />. Because the data is in memory, the
        /// returned task is already completed.
        /// </summary>
        /// <param name="cancellationToken">Checked once before the lookup is performed.</param>
        /// <param name="keyValues">Passed through to <see cref="Find(object[])" /> (which does not use them).</param>
        /// <returns>A completed task whose result is the value <see cref="Find(object[])" /> returns.</returns>
        /// <exception cref="OperationCanceledException"><paramref name="cancellationToken" /> is already canceled.</exception>
        /// <exception cref="NotSupportedException">No find function was supplied at construction.</exception>
        public Task<T> FindAsync(CancellationToken cancellationToken, params object[] keyValues)
        {
            cancellationToken.ThrowIfCancellationRequested();

            return Task.FromResult(Find(keyValues));
        }

        /// <summary>
        /// Adds <paramref name="item" /> to the set. Adding an item that is already present has no effect.
        /// </summary>
        /// <returns><paramref name="item" /></returns>
        public T Add(T item)
        {
            _data.Add(item);
            return item;
        }

        /// <summary>
        /// Removes <paramref name="item" /> from the set. Removing an item that is not present has no effect.
        /// </summary>
        /// <returns><paramref name="item" />, whether or not it was present.</returns>
        public T Remove(T item)
        {
            _data.Remove(item);
            return item;
        }

        /// <summary>
        /// Attaches <paramref name="item" />. There is no change tracking here, so this behaves exactly like
        /// <see cref="Add(T)" />.
        /// </summary>
        /// <returns><paramref name="item" /></returns>
        public T Attach(T item)
        {
            _data.Add(item);
            return item;
        }

        Type IQueryable.ElementType
        {
            get { return _query.ElementType; }
        }

        Expression IQueryable.Expression
        {
            get { return _query.Expression; }
        }

        IQueryProvider IQueryable.Provider
        {
            get { return _query.Provider; }
        }

        IEnumerator<T> IEnumerable<T>.GetEnumerator()
        {
            return _data.GetEnumerator();
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return _data.GetEnumerator();
        }

        /// <summary>
        /// Gets a snapshot of the set's current contents.
        /// </summary>
        /// <remarks>
        /// A new collection is built on every access. Changes made to the returned collection are not
        /// written back to the set, and later changes to the set do not appear in it.
        /// </remarks>
        public ObservableCollection<T> Local
        {
            get { return new ObservableCollection<T>(_data); }
        }

        /// <summary>
        /// Creates a new, unattached instance of <typeparamref name="T" />. The instance is not added to the set.
        /// </summary>
        public T Create()
        {
            return new T();
        }

        /// <summary>
        /// Creates a new, unattached instance of <typeparamref name="TDerivedEntity" />. The instance is not
        /// added to the set.
        /// </summary>
        /// <typeparam name="TDerivedEntity">Entity type to create; must derive from <typeparamref name="T" />.</typeparam>
        /// <exception cref="MissingMethodException">
        /// <typeparamref name="TDerivedEntity" /> has no public parameterless constructor.
        /// </exception>
        public TDerivedEntity Create<TDerivedEntity>()
            where TDerivedEntity : class, T
        {
            // This method's constraint has no new(), so `new TDerivedEntity()` is not allowed; use reflection instead.
            return Activator.CreateInstance<TDerivedEntity>();
        }
    }
}