From 2d53b5dc879dadea41f0f1ed59daa056722f8129 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Sto=C5=82owski?= Date: Wed, 17 Nov 2021 14:01:07 +0100 Subject: [PATCH 1/4] Provide a way for usersession client to send messages only to a subset of clients. --- usersession/client/client.go | 38 +++++++++++++++++++++++++++---- usersession/client/client_test.go | 18 +++++++++++++++ 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/usersession/client/client.go b/usersession/client/client.go index 602091f7ab..9554257fcc 100644 --- a/usersession/client/client.go +++ b/usersession/client/client.go @@ -52,6 +52,7 @@ func dialSessionAgent(network, address string) (net.Conn, error) { type Client struct { doer *http.Client + uids []int } func New() *Client { @@ -61,6 +62,14 @@ func New() *Client { } } +// NewForUids creates a Client that sends requests to a specific list of uids +// only. +func NewForUids(uids ...int) *Client { + cli := New() + cli.uids = append(cli.uids, uids...) + return cli +} + type Error struct { Kind string `json:"kind"` Value interface{} `json:"value"` @@ -132,11 +141,11 @@ func (client *Client) sendRequest(ctx context.Context, socket string, method, ur return response } -// doMany sends the given request to all active user sessions. Please be -// careful when using this method, because it is not aware of the physical user -// who triggered the request and blindly forwards it to all logged in users. -// Some of them might not have the right to see the request (let alone to -// respond to it). +// doMany sends the given request to all active user sessions or a subset of them +// defined by optional client.uids field. Please be careful when using this +// method, because it is not aware of the physical user who triggered the request +// and blindly forwards it to all logged in users. Some of them might not have +// the right to see the request (let alone to respond to it). func (client *Client) doMany(ctx context.Context, method, urlpath string, query url.Values, headers map[string]string, body []byte) ([]*response, error) { sockets, err := filepath.Glob(filepath.Join(dirs.XdgRuntimeDirGlob, "snapd-session-agent.socket")) if err != nil { @@ -147,7 +156,26 @@ func (client *Client) doMany(ctx context.Context, method, urlpath string, query mu sync.Mutex responses []*response ) + + var uids map[string]bool + if len(client.uids) > 0 { + uids = make(map[string]bool) + for _, uid := range client.uids { + uids[fmt.Sprintf("%d", uid)] = true + } + } + for _, socket := range sockets { + // filter out sockets based on uids + if len(uids) > 0 { + // XXX: alternatively we could Stat() the socket and + // and check Uid field of stat.Sys().(*syscall.Stat_t), but it's + // more annyoing to unit-test. + userPart := filepath.Base(filepath.Dir(socket)) + if !uids[userPart] { + continue + } + } wg.Add(1) go func(socket string) { defer wg.Done() diff --git a/usersession/client/client_test.go b/usersession/client/client_test.go index 1ee96531f7..40090c8e10 100644 --- a/usersession/client/client_test.go +++ b/usersession/client/client_test.go @@ -448,7 +448,9 @@ func (s *clientSuite) TestServicesStopFailure(c *C) { } func (s *clientSuite) TestPendingRefreshNotification(c *C) { + n := 0 s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n++ c.Assert(r.URL.Path, Equals, "/v1/notifications/pending-refresh") w.Header().Set("Content-Type", "application/json") w.WriteHeader(200) @@ -456,4 +458,20 @@ func (s *clientSuite) TestPendingRefreshNotification(c *C) { }) err := s.cli.PendingRefreshNotification(context.Background(), &client.PendingSnapRefreshInfo{}) c.Assert(err, IsNil) + c.Check(n, Equals, 2) +} + +func (s *clientSuite) TestPendingRefreshNotificationOneClient(c *C) { + cli := client.NewForUids(1000) + n := 0 + s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n++ + c.Assert(r.URL.Path, Equals, "/v1/notifications/pending-refresh") + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + w.Write([]byte(`{"type": "sync"}`)) + }) + err := cli.PendingRefreshNotification(context.Background(), &client.PendingSnapRefreshInfo{}) + c.Assert(err, IsNil) + c.Check(n, Equals, 1) } From 44d9f7ab1ce7689e4a8e33784e27311b5d98e242 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Sto=C5=82owski?= Date: Wed, 24 Nov 2021 12:06:27 +0100 Subject: [PATCH 2/4] Use atomics for client test to avoid race when the handler is called in parallel (thanks jhenstridge). --- usersession/client/client_test.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/usersession/client/client_test.go b/usersession/client/client_test.go index 40090c8e10..299cf5c9b2 100644 --- a/usersession/client/client_test.go +++ b/usersession/client/client_test.go @@ -26,6 +26,7 @@ import ( "net/http" "os" "path/filepath" + "sync/atomic" "testing" "time" @@ -448,9 +449,9 @@ func (s *clientSuite) TestServicesStopFailure(c *C) { } func (s *clientSuite) TestPendingRefreshNotification(c *C) { - n := 0 + var n int32 s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - n++ + atomic.AddInt32(&n, 1) c.Assert(r.URL.Path, Equals, "/v1/notifications/pending-refresh") w.Header().Set("Content-Type", "application/json") w.WriteHeader(200) @@ -458,14 +459,14 @@ func (s *clientSuite) TestPendingRefreshNotification(c *C) { }) err := s.cli.PendingRefreshNotification(context.Background(), &client.PendingSnapRefreshInfo{}) c.Assert(err, IsNil) - c.Check(n, Equals, 2) + c.Check(atomic.LoadInt32(&n), Equals, int32(2)) } func (s *clientSuite) TestPendingRefreshNotificationOneClient(c *C) { cli := client.NewForUids(1000) - n := 0 + var n int32 s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - n++ + atomic.AddInt32(&n, 1) c.Assert(r.URL.Path, Equals, "/v1/notifications/pending-refresh") w.Header().Set("Content-Type", "application/json") w.WriteHeader(200) @@ -473,5 +474,5 @@ func (s *clientSuite) TestPendingRefreshNotificationOneClient(c *C) { }) err := cli.PendingRefreshNotification(context.Background(), &client.PendingSnapRefreshInfo{}) c.Assert(err, IsNil) - c.Check(n, Equals, 1) + c.Check(atomic.LoadInt32(&n), Equals, int32(1)) } From 014b0acb6677ff2e8935589cee678010db102289 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Sto=C5=82owski?= Date: Wed, 17 Nov 2021 14:01:07 +0100 Subject: [PATCH 3/4] Provide a way for usersession client to send messages only to a subset of clients. --- usersession/client/client.go | 38 +++++++++++++++++++++++++++---- usersession/client/client_test.go | 18 +++++++++++++++ 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/usersession/client/client.go b/usersession/client/client.go index 602091f7ab..9554257fcc 100644 --- a/usersession/client/client.go +++ b/usersession/client/client.go @@ -52,6 +52,7 @@ func dialSessionAgent(network, address string) (net.Conn, error) { type Client struct { doer *http.Client + uids []int } func New() *Client { @@ -61,6 +62,14 @@ func New() *Client { } } +// NewForUids creates a Client that sends requests to a specific list of uids +// only. +func NewForUids(uids ...int) *Client { + cli := New() + cli.uids = append(cli.uids, uids...) + return cli +} + type Error struct { Kind string `json:"kind"` Value interface{} `json:"value"` @@ -132,11 +141,11 @@ func (client *Client) sendRequest(ctx context.Context, socket string, method, ur return response } -// doMany sends the given request to all active user sessions. Please be -// careful when using this method, because it is not aware of the physical user -// who triggered the request and blindly forwards it to all logged in users. -// Some of them might not have the right to see the request (let alone to -// respond to it). +// doMany sends the given request to all active user sessions or a subset of them +// defined by optional client.uids field. Please be careful when using this +// method, because it is not aware of the physical user who triggered the request +// and blindly forwards it to all logged in users. Some of them might not have +// the right to see the request (let alone to respond to it). func (client *Client) doMany(ctx context.Context, method, urlpath string, query url.Values, headers map[string]string, body []byte) ([]*response, error) { sockets, err := filepath.Glob(filepath.Join(dirs.XdgRuntimeDirGlob, "snapd-session-agent.socket")) if err != nil { @@ -147,7 +156,26 @@ func (client *Client) doMany(ctx context.Context, method, urlpath string, query mu sync.Mutex responses []*response ) + + var uids map[string]bool + if len(client.uids) > 0 { + uids = make(map[string]bool) + for _, uid := range client.uids { + uids[fmt.Sprintf("%d", uid)] = true + } + } + for _, socket := range sockets { + // filter out sockets based on uids + if len(uids) > 0 { + // XXX: alternatively we could Stat() the socket and + // and check Uid field of stat.Sys().(*syscall.Stat_t), but it's + // more annyoing to unit-test. + userPart := filepath.Base(filepath.Dir(socket)) + if !uids[userPart] { + continue + } + } wg.Add(1) go func(socket string) { defer wg.Done() diff --git a/usersession/client/client_test.go b/usersession/client/client_test.go index 1ee96531f7..40090c8e10 100644 --- a/usersession/client/client_test.go +++ b/usersession/client/client_test.go @@ -448,7 +448,9 @@ func (s *clientSuite) TestServicesStopFailure(c *C) { } func (s *clientSuite) TestPendingRefreshNotification(c *C) { + n := 0 s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n++ c.Assert(r.URL.Path, Equals, "/v1/notifications/pending-refresh") w.Header().Set("Content-Type", "application/json") w.WriteHeader(200) @@ -456,4 +458,20 @@ func (s *clientSuite) TestPendingRefreshNotification(c *C) { }) err := s.cli.PendingRefreshNotification(context.Background(), &client.PendingSnapRefreshInfo{}) c.Assert(err, IsNil) + c.Check(n, Equals, 2) +} + +func (s *clientSuite) TestPendingRefreshNotificationOneClient(c *C) { + cli := client.NewForUids(1000) + n := 0 + s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n++ + c.Assert(r.URL.Path, Equals, "/v1/notifications/pending-refresh") + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + w.Write([]byte(`{"type": "sync"}`)) + }) + err := cli.PendingRefreshNotification(context.Background(), &client.PendingSnapRefreshInfo{}) + c.Assert(err, IsNil) + c.Check(n, Equals, 1) } From a3ac982b20e9b1c12c07e9565b94f9ed6a722482 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Sto=C5=82owski?= Date: Wed, 24 Nov 2021 12:06:27 +0100 Subject: [PATCH 4/4] Use atomics for client test to avoid race when the handler is called in parallel (thanks jhenstridge). --- usersession/client/client_test.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/usersession/client/client_test.go b/usersession/client/client_test.go index 40090c8e10..299cf5c9b2 100644 --- a/usersession/client/client_test.go +++ b/usersession/client/client_test.go @@ -26,6 +26,7 @@ import ( "net/http" "os" "path/filepath" + "sync/atomic" "testing" "time" @@ -448,9 +449,9 @@ func (s *clientSuite) TestServicesStopFailure(c *C) { } func (s *clientSuite) TestPendingRefreshNotification(c *C) { - n := 0 + var n int32 s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - n++ + atomic.AddInt32(&n, 1) c.Assert(r.URL.Path, Equals, "/v1/notifications/pending-refresh") w.Header().Set("Content-Type", "application/json") w.WriteHeader(200) @@ -458,14 +459,14 @@ func (s *clientSuite) TestPendingRefreshNotification(c *C) { }) err := s.cli.PendingRefreshNotification(context.Background(), &client.PendingSnapRefreshInfo{}) c.Assert(err, IsNil) - c.Check(n, Equals, 2) + c.Check(atomic.LoadInt32(&n), Equals, int32(2)) } func (s *clientSuite) TestPendingRefreshNotificationOneClient(c *C) { cli := client.NewForUids(1000) - n := 0 + var n int32 s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - n++ + atomic.AddInt32(&n, 1) c.Assert(r.URL.Path, Equals, "/v1/notifications/pending-refresh") w.Header().Set("Content-Type", "application/json") w.WriteHeader(200) @@ -473,5 +474,5 @@ func (s *clientSuite) TestPendingRefreshNotificationOneClient(c *C) { }) err := cli.PendingRefreshNotification(context.Background(), &client.PendingSnapRefreshInfo{}) c.Assert(err, IsNil) - c.Check(n, Equals, 1) + c.Check(atomic.LoadInt32(&n), Equals, int32(1)) }