Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 52 additions & 54 deletions connector/microsoft/microsoft.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ const (
)

const (
// Microsoft requires this scope to access user's profile
scopeUser = "user.read"
// Microsoft requires this scope to list groups the user is a member of
// and resolve their ids to groups names.
scopeGroups = "directory.read.all"
// Microsoft requires the scopes to start with openid
scopeOpenID = "openid"
// Get the permissions configured on the application registration
// see https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-permissions-and-consent#the-default-scope
scopeDefault = "https://graph.microsoft.com/.default"
// Microsoft requires this scope to return a refresh token
// see https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-permissions-and-consent#offline_access
scopeOfflineAccess = "offline_access"
Expand All @@ -62,7 +62,7 @@ type Config struct {
PromptType string `json:"promptType"`
DomainHint string `json:"domainHint"`

Scopes []string `json:"scopes"` // defaults to scopeUser (user.read)
Scopes []string `json:"scopes"` // defaults to scopeOpenID (openid)
}

// Open returns a strategy for logging in through Microsoft.
Expand Down Expand Up @@ -153,11 +153,9 @@ func (c *microsoftConnector) oauth2Config(scopes connector.Scopes) *oauth2.Confi
if len(c.scopes) > 0 {
microsoftScopes = c.scopes
} else {
microsoftScopes = append(microsoftScopes, scopeUser)
}
if c.groupsRequired(scopes.Groups) {
microsoftScopes = append(microsoftScopes, scopeGroups)
microsoftScopes = append(microsoftScopes, scopeOpenID)
}
microsoftScopes = append(microsoftScopes, scopeDefault)

if scopes.OfflineAccess {
microsoftScopes = append(microsoftScopes, scopeOfflineAccess)
Expand Down Expand Up @@ -386,21 +384,15 @@ func (c *microsoftConnector) user(ctx context.Context, client *http.Client) (u u
// Supports $filter and $orderby.
type group struct {
Name string `json:"displayName"`
Id string `json:"id,omitempty"`
}

func (c *microsoftConnector) getGroups(ctx context.Context, client *http.Client, userID string) ([]string, error) {
userGroups, err := c.getGroupIDs(ctx, client)
userGroups, err := c.queryGroups(ctx, client)
if err != nil {
return nil, err
}

if c.groupNameFormat == GroupName {
userGroups, err = c.getGroupNames(ctx, client, userGroups)
if err != nil {
return nil, err
}
}

// ensure that the user is in at least one required group
filteredGroups := groups_pkg.Filter(userGroups, c.groups)
if len(c.groups) > 0 && len(filteredGroups) == 0 {
Expand All @@ -412,51 +404,26 @@ func (c *microsoftConnector) getGroups(ctx context.Context, client *http.Client,
return userGroups, nil
}

func (c *microsoftConnector) getGroupIDs(ctx context.Context, client *http.Client) (ids []string, err error) {
// https://developer.microsoft.com/en-us/graph/docs/api-reference/v1.0/api/user_getmembergroups
in := &struct {
SecurityEnabledOnly bool `json:"securityEnabledOnly"`
}{c.onlySecurityGroups}
reqURL := c.graphURL + "/v1.0/me/getMemberGroups"
for {
var out []string
var next string

next, err = c.post(ctx, client, reqURL, in, &out)
if err != nil {
return ids, err
}

ids = append(ids, out...)
if next == "" {
return
}
reqURL = next
}
}

func (c *microsoftConnector) getGroupNames(ctx context.Context, client *http.Client, ids []string) (groups []string, err error) {
if len(ids) == 0 {
return
}

// https://developer.microsoft.com/en-us/graph/docs/api-reference/v1.0/api/directoryobject_getbyids
in := &struct {
IDs []string `json:"ids"`
Types []string `json:"types"`
}{ids, []string{"group"}}
reqURL := c.graphURL + "/v1.0/directoryObjects/getByIds"
func (c *microsoftConnector) queryGroups(ctx context.Context, client *http.Client) (groups []string, err error) {
reqURL := c.graphURL + "/v1.0/me/memberOf/microsoft.graph.group?$select=displayName,id"
for {
var out []group
var next string

next, err = c.post(ctx, client, reqURL, in, &out)
next, err = c.get(ctx, client, reqURL, &out)
if err != nil {
c.logger.Info("resolved groups", "groups", groups, "error", err.Error())
return groups, err
}

for _, g := range out {
groups = append(groups, g.Name)
if c.groupNameFormat == GroupName {
c.logger.Info("resolved another group", "name", g.Name)
groups = append(groups, g.Name)
} else {
c.logger.Info("resolved another group", "id", g.Id)
groups = append(groups, g.Id)
}
}
if next == "" {
return
Expand All @@ -466,6 +433,7 @@ func (c *microsoftConnector) getGroupNames(ctx context.Context, client *http.Cli
}

func (c *microsoftConnector) post(ctx context.Context, client *http.Client, reqURL string, in interface{}, out interface{}) (string, error) {
c.logger.Info("post url", "url", reqURL)
var payload bytes.Buffer

err := json.NewEncoder(&payload).Encode(in)
Expand Down Expand Up @@ -500,6 +468,36 @@ func (c *microsoftConnector) post(ctx context.Context, client *http.Client, reqU
return next, nil
}

func (c *microsoftConnector) get(ctx context.Context, client *http.Client, reqURL string, out interface{}) (string, error) {
c.logger.Info("get url", "url", reqURL)

req, err := http.NewRequest("GET", reqURL, nil)
if err != nil {
return "", fmt.Errorf("new req: %v", err)
}
req.Header.Set("Content-Type", "application/json")

resp, err := client.Do(req.WithContext(ctx))
if err != nil {
return "", fmt.Errorf("get URL %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return "", newGraphError(resp.Body)
}

var next string
if err = json.NewDecoder(resp.Body).Decode(&struct {
NextLink *string `json:"@odata.nextLink"`
Value interface{} `json:"value"`
}{&next, out}); err != nil {
return "", fmt.Errorf("JSON decode: %v", err)
}

return next, nil
}

type graphError struct {
Code string `json:"code"`
Message string `json:"message"`
Expand Down
150 changes: 146 additions & 4 deletions connector/microsoft/microsoft_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package microsoft
import (
"encoding/json"
"fmt"
"log/slog"
"net/http"
"net/http/httptest"
"net/url"
Expand Down Expand Up @@ -48,7 +49,7 @@ func TestLoginURL(t *testing.T) {
expectEquals(t, queryParams.Get("client_id"), clientID)
expectEquals(t, queryParams.Get("redirect_uri"), testURL)
expectEquals(t, queryParams.Get("response_type"), "code")
expectEquals(t, queryParams.Get("scope"), "user.read")
expectEquals(t, queryParams.Get("scope"), "openid https://graph.microsoft.com/.default")
expectEquals(t, queryParams.Get("state"), testState)
expectEquals(t, queryParams.Get("prompt"), "")
expectEquals(t, queryParams.Get("domain_hint"), "")
Expand Down Expand Up @@ -104,21 +105,162 @@ func TestUserIdentityFromGraphAPI(t *testing.T) {
func TestUserGroupsFromGraphAPI(t *testing.T) {
s := newTestServer(map[string]testResponse{
"/v1.0/me?$select=id,displayName,userPrincipalName": {data: user{}},
"/v1.0/me/getMemberGroups": {data: map[string]interface{}{
"value": []string{"a", "b"},
"/v1.0/me/memberOf/microsoft.graph.group?$select=displayName,id": {data: map[string]interface{}{
"value": []group{{Name: "a", Id: "1"}, {Name: "b", Id: "2"}},
}},
"/" + tenant + "/oauth2/v2.0/token": dummyToken,
})
defer s.Close()

req, _ := http.NewRequest("GET", s.URL, nil)

c := microsoftConnector{apiURL: s.URL, graphURL: s.URL, tenant: tenant}
c := microsoftConnector{apiURL: s.URL, graphURL: s.URL, tenant: tenant, logger: slog.Default(), groupNameFormat: GroupName}
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req)
expectNil(t, err)
expectEquals(t, identity.Groups, []string{"a", "b"})
}

func TestUserGroupsWithGroupIDFormat(t *testing.T) {
s := newTestServer(map[string]testResponse{
"/v1.0/me?$select=id,displayName,userPrincipalName": {data: user{}},
"/v1.0/me/memberOf/microsoft.graph.group?$select=displayName,id": {data: map[string]interface{}{
"value": []group{{Name: "GroupA", Id: "id-1"}, {Name: "GroupB", Id: "id-2"}},
}},
"/" + tenant + "/oauth2/v2.0/token": dummyToken,
})
defer s.Close()

req, _ := http.NewRequest("GET", s.URL, nil)

c := microsoftConnector{apiURL: s.URL, graphURL: s.URL, tenant: tenant, logger: slog.Default(), groupNameFormat: GroupID}
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req)
expectNil(t, err)
expectEquals(t, identity.Groups, []string{"id-1", "id-2"})
}

func TestLoginURLWithCustomScopes(t *testing.T) {
testURL := "https://test.com"
testState := "some-state"
customScopes := []string{"custom.scope1", "custom.scope2"}

conn := microsoftConnector{
apiURL: testURL,
graphURL: testURL,
redirectURI: testURL,
clientID: clientID,
tenant: tenant,
scopes: customScopes,
}

loginURL, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, testState)

parsedLoginURL, _ := url.Parse(loginURL)
queryParams := parsedLoginURL.Query()

// Custom scopes should be used, plus the default scope is always appended
expectEquals(t, queryParams.Get("scope"), "custom.scope1 custom.scope2 https://graph.microsoft.com/.default")
}

func TestLoginURLWithOfflineAccess(t *testing.T) {
testURL := "https://test.com"
testState := "some-state"

conn := microsoftConnector{
apiURL: testURL,
graphURL: testURL,
redirectURI: testURL,
clientID: clientID,
tenant: tenant,
}

loginURL, _ := conn.LoginURL(connector.Scopes{OfflineAccess: true}, conn.redirectURI, testState)

parsedLoginURL, _ := url.Parse(loginURL)
queryParams := parsedLoginURL.Query()

expectEquals(t, queryParams.Get("scope"), "openid https://graph.microsoft.com/.default offline_access")
}

func TestUserGroupsWithWhitelist(t *testing.T) {
s := newTestServer(map[string]testResponse{
"/v1.0/me?$select=id,displayName,userPrincipalName": {data: user{ID: "user123"}},
"/v1.0/me/memberOf/microsoft.graph.group?$select=displayName,id": {data: map[string]interface{}{
"value": []group{{Name: "allowed-group", Id: "1"}, {Name: "other-group", Id: "2"}},
}},
"/" + tenant + "/oauth2/v2.0/token": dummyToken,
})
defer s.Close()

req, _ := http.NewRequest("GET", s.URL, nil)

c := microsoftConnector{
apiURL: s.URL,
graphURL: s.URL,
tenant: tenant,
logger: slog.Default(),
groupNameFormat: GroupName,
groups: []string{"allowed-group"},
useGroupsAsWhitelist: true,
}
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req)
expectNil(t, err)
// Only the whitelisted group should be returned
expectEquals(t, identity.Groups, []string{"allowed-group"})
}

func TestUserGroupsNotInRequiredGroups(t *testing.T) {
s := newTestServer(map[string]testResponse{
"/v1.0/me?$select=id,displayName,userPrincipalName": {data: user{ID: "user123"}},
"/v1.0/me/memberOf/microsoft.graph.group?$select=displayName,id": {data: map[string]interface{}{
"value": []group{{Name: "some-group", Id: "1"}},
}},
"/" + tenant + "/oauth2/v2.0/token": dummyToken,
})
defer s.Close()

req, _ := http.NewRequest("GET", s.URL, nil)

c := microsoftConnector{
apiURL: s.URL,
graphURL: s.URL,
tenant: tenant,
logger: slog.Default(),
groupNameFormat: GroupName,
groups: []string{"required-group"}, // User is not in this group
}
_, err := c.HandleCallback(connector.Scopes{Groups: true}, req)
// Should fail because user is not in required group
if err == nil {
t.Error("Expected error when user is not in required groups")
}
}

func TestUserGroupsInRequiredGroups(t *testing.T) {
s := newTestServer(map[string]testResponse{
"/v1.0/me?$select=id,displayName,userPrincipalName": {data: user{ID: "user123"}},
"/v1.0/me/memberOf/microsoft.graph.group?$select=displayName,id": {data: map[string]interface{}{
"value": []group{{Name: "required-group", Id: "1"}, {Name: "other-group", Id: "2"}},
}},
"/" + tenant + "/oauth2/v2.0/token": dummyToken,
})
defer s.Close()

req, _ := http.NewRequest("GET", s.URL, nil)

c := microsoftConnector{
apiURL: s.URL,
graphURL: s.URL,
tenant: tenant,
logger: slog.Default(),
groupNameFormat: GroupName,
groups: []string{"required-group"},
}
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req)
expectNil(t, err)
// All groups should be returned (not filtered) when useGroupsAsWhitelist is false
expectEquals(t, identity.Groups, []string{"required-group", "other-group"})
}

func newTestServer(responses map[string]testResponse) *httptest.Server {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response, found := responses[r.RequestURI]
Expand Down