diff --git a/usersession/client/client.go b/usersession/client/client.go index 602091f7ab..9554257fcc 100644 --- a/usersession/client/client.go +++ b/usersession/client/client.go @@ -52,6 +52,7 @@ func dialSessionAgent(network, address string) (net.Conn, error) { type Client struct { doer *http.Client + uids []int } func New() *Client { @@ -61,6 +62,14 @@ func New() *Client { } } +// NewForUids creates a Client that sends requests to a specific list of uids +// only. +func NewForUids(uids ...int) *Client { + cli := New() + cli.uids = append(cli.uids, uids...) + return cli +} + type Error struct { Kind string `json:"kind"` Value interface{} `json:"value"` @@ -132,11 +141,11 @@ func (client *Client) sendRequest(ctx context.Context, socket string, method, ur return response } -// doMany sends the given request to all active user sessions. Please be -// careful when using this method, because it is not aware of the physical user -// who triggered the request and blindly forwards it to all logged in users. -// Some of them might not have the right to see the request (let alone to -// respond to it). +// doMany sends the given request to all active user sessions or a subset of them +// defined by optional client.uids field. Please be careful when using this +// method, because it is not aware of the physical user who triggered the request +// and blindly forwards it to all logged in users. Some of them might not have +// the right to see the request (let alone to respond to it). func (client *Client) doMany(ctx context.Context, method, urlpath string, query url.Values, headers map[string]string, body []byte) ([]*response, error) { sockets, err := filepath.Glob(filepath.Join(dirs.XdgRuntimeDirGlob, "snapd-session-agent.socket")) if err != nil { @@ -147,7 +156,26 @@ func (client *Client) doMany(ctx context.Context, method, urlpath string, query mu sync.Mutex responses []*response ) + + var uids map[string]bool + if len(client.uids) > 0 { + uids = make(map[string]bool) + for _, uid := range client.uids { + uids[fmt.Sprintf("%d", uid)] = true + } + } + for _, socket := range sockets { + // filter out sockets based on uids + if len(uids) > 0 { + // XXX: alternatively we could Stat() the socket and + // and check Uid field of stat.Sys().(*syscall.Stat_t), but it's + // more annyoing to unit-test. + userPart := filepath.Base(filepath.Dir(socket)) + if !uids[userPart] { + continue + } + } wg.Add(1) go func(socket string) { defer wg.Done() diff --git a/usersession/client/client_test.go b/usersession/client/client_test.go index 1ee96531f7..299cf5c9b2 100644 --- a/usersession/client/client_test.go +++ b/usersession/client/client_test.go @@ -26,6 +26,7 @@ import ( "net/http" "os" "path/filepath" + "sync/atomic" "testing" "time" @@ -448,7 +449,9 @@ func (s *clientSuite) TestServicesStopFailure(c *C) { } func (s *clientSuite) TestPendingRefreshNotification(c *C) { + var n int32 s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&n, 1) c.Assert(r.URL.Path, Equals, "/v1/notifications/pending-refresh") w.Header().Set("Content-Type", "application/json") w.WriteHeader(200) @@ -456,4 +459,20 @@ func (s *clientSuite) TestPendingRefreshNotification(c *C) { }) err := s.cli.PendingRefreshNotification(context.Background(), &client.PendingSnapRefreshInfo{}) c.Assert(err, IsNil) + c.Check(atomic.LoadInt32(&n), Equals, int32(2)) +} + +func (s *clientSuite) TestPendingRefreshNotificationOneClient(c *C) { + cli := client.NewForUids(1000) + var n int32 + s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&n, 1) + c.Assert(r.URL.Path, Equals, "/v1/notifications/pending-refresh") + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + w.Write([]byte(`{"type": "sync"}`)) + }) + err := cli.PendingRefreshNotification(context.Background(), &client.PendingSnapRefreshInfo{}) + c.Assert(err, IsNil) + c.Check(atomic.LoadInt32(&n), Equals, int32(1)) }