Merge pull request #11072 from stolowski/notifications/client-user-subset

usersession/client: provide a way for client to send messages to a subset of users
This commit is contained in:
Pawel Stolowski
2021-12-03 12:44:34 +01:00
committed by GitHub
2 changed files with 52 additions and 5 deletions

View File

@@ -52,6 +52,7 @@ func dialSessionAgent(network, address string) (net.Conn, error) {
type Client struct {
doer *http.Client
uids []int
}
func New() *Client {
@@ -61,6 +62,14 @@ func New() *Client {
}
}
// NewForUids creates a Client that sends requests to a specific list of uids
// only.
func NewForUids(uids ...int) *Client {
cli := New()
cli.uids = append(cli.uids, uids...)
return cli
}
type Error struct {
Kind string `json:"kind"`
Value interface{} `json:"value"`
@@ -132,11 +141,11 @@ func (client *Client) sendRequest(ctx context.Context, socket string, method, ur
return response
}
// doMany sends the given request to all active user sessions. Please be
// careful when using this method, because it is not aware of the physical user
// who triggered the request and blindly forwards it to all logged in users.
// Some of them might not have the right to see the request (let alone to
// respond to it).
// doMany sends the given request to all active user sessions or a subset of them
// defined by optional client.uids field. Please be careful when using this
// method, because it is not aware of the physical user who triggered the request
// and blindly forwards it to all logged in users. Some of them might not have
// the right to see the request (let alone to respond to it).
func (client *Client) doMany(ctx context.Context, method, urlpath string, query url.Values, headers map[string]string, body []byte) ([]*response, error) {
sockets, err := filepath.Glob(filepath.Join(dirs.XdgRuntimeDirGlob, "snapd-session-agent.socket"))
if err != nil {
@@ -147,7 +156,26 @@ func (client *Client) doMany(ctx context.Context, method, urlpath string, query
mu sync.Mutex
responses []*response
)
var uids map[string]bool
if len(client.uids) > 0 {
uids = make(map[string]bool)
for _, uid := range client.uids {
uids[fmt.Sprintf("%d", uid)] = true
}
}
for _, socket := range sockets {
// filter out sockets based on uids
if len(uids) > 0 {
// XXX: alternatively we could Stat() the socket and
// and check Uid field of stat.Sys().(*syscall.Stat_t), but it's
// more annyoing to unit-test.
userPart := filepath.Base(filepath.Dir(socket))
if !uids[userPart] {
continue
}
}
wg.Add(1)
go func(socket string) {
defer wg.Done()

View File

@@ -26,6 +26,7 @@ import (
"net/http"
"os"
"path/filepath"
"sync/atomic"
"testing"
"time"
@@ -448,7 +449,9 @@ func (s *clientSuite) TestServicesStopFailure(c *C) {
}
func (s *clientSuite) TestPendingRefreshNotification(c *C) {
var n int32
s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&n, 1)
c.Assert(r.URL.Path, Equals, "/v1/notifications/pending-refresh")
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
@@ -456,4 +459,20 @@ func (s *clientSuite) TestPendingRefreshNotification(c *C) {
})
err := s.cli.PendingRefreshNotification(context.Background(), &client.PendingSnapRefreshInfo{})
c.Assert(err, IsNil)
c.Check(atomic.LoadInt32(&n), Equals, int32(2))
}
func (s *clientSuite) TestPendingRefreshNotificationOneClient(c *C) {
cli := client.NewForUids(1000)
var n int32
s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&n, 1)
c.Assert(r.URL.Path, Equals, "/v1/notifications/pending-refresh")
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
w.Write([]byte(`{"type": "sync"}`))
})
err := cli.PendingRefreshNotification(context.Background(), &client.PendingSnapRefreshInfo{})
c.Assert(err, IsNil)
c.Check(atomic.LoadInt32(&n), Equals, int32(1))
}