Skip to content

Commit

Permalink
Merge pull request casbin#610 from nodece/beta
Browse files Browse the repository at this point in the history
fix: compatible with Model of v2
  • Loading branch information
hsluoyz committed Oct 4, 2020
2 parents 21b57de + 561c2d1 commit d9f533d
Show file tree
Hide file tree
Showing 11 changed files with 109 additions and 166 deletions.
12 changes: 6 additions & 6 deletions enforcer.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import (
// Enforcer is the main interface for authorization enforcement and policy management.
type Enforcer struct {
modelPath string
model *model.Model
model model.Model
fm model.FunctionMap
eft effect.Effector

Expand Down Expand Up @@ -101,7 +101,7 @@ func NewEnforcer(params ...interface{}) (*Enforcer, error) {
case string:
return nil, errors.New("invalid parameters for enforcer")
default:
err := e.InitWithModelAndAdapter(p0.(*model.Model), params[1].(persist.Adapter))
err := e.InitWithModelAndAdapter(p0.(model.Model), params[1].(persist.Adapter))
if err != nil {
return nil, err
}
Expand All @@ -115,7 +115,7 @@ func NewEnforcer(params ...interface{}) (*Enforcer, error) {
return nil, err
}
default:
err := e.InitWithModelAndAdapter(p0.(*model.Model), nil)
err := e.InitWithModelAndAdapter(p0.(model.Model), nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -152,7 +152,7 @@ func (e *Enforcer) InitWithAdapter(modelPath string, adapter persist.Adapter) er
}

// InitWithModelAndAdapter initializes an enforcer with a model and a database adapter.
func (e *Enforcer) InitWithModelAndAdapter(m *model.Model, adapter persist.Adapter) error {
func (e *Enforcer) InitWithModelAndAdapter(m model.Model, adapter persist.Adapter) error {
e.adapter = adapter

e.model = m
Expand Down Expand Up @@ -201,12 +201,12 @@ func (e *Enforcer) LoadModel() error {
}

// GetModel gets the current model.
func (e *Enforcer) GetModel() *model.Model {
func (e *Enforcer) GetModel() model.Model {
return e.model
}

// SetModel sets the current model.
func (e *Enforcer) SetModel(m *model.Model) {
func (e *Enforcer) SetModel(m model.Model) {
e.model = m
e.fm = model.LoadFunctionMap()
e.internal = internal.NewPolicyManager(m, e.adapter, e.rm)
Expand Down
6 changes: 3 additions & 3 deletions enforcer_interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,10 @@ type IEnforcer interface {
/* Enforcer API */
InitWithFile(modelPath string, policyPath string) error
InitWithAdapter(modelPath string, adapter persist.Adapter) error
InitWithModelAndAdapter(m *model.Model, adapter persist.Adapter) error
InitWithModelAndAdapter(m model.Model, adapter persist.Adapter) error
LoadModel() error
GetModel() *model.Model
SetModel(m *model.Model)
GetModel() model.Model
SetModel(m model.Model)
GetAdapter() persist.Adapter
SetAdapter(adapter persist.Adapter)
SetWatcher(watcher persist.Watcher) error
Expand Down
4 changes: 2 additions & 2 deletions internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type PolicyManager interface {
}

type policyManager struct {
model *model.Model
model model.Model
adapter persist.Adapter
rm rbac.RoleManager
}
Expand All @@ -41,7 +41,7 @@ const (
)

// NewPolicyManager is the constructor for PolicyManager
func NewPolicyManager(model *model.Model, adapter persist.Adapter, rm rbac.RoleManager) PolicyManager {
func NewPolicyManager(model model.Model, adapter persist.Adapter, rm rbac.RoleManager) PolicyManager {
return &policyManager{
model: model,
adapter: adapter,
Expand Down
89 changes: 33 additions & 56 deletions model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"fmt"
"strconv"
"strings"
"sync"

"github.com/Knetic/govaluate"
"github.com/casbin/casbin/v3/rbac"
Expand All @@ -29,10 +28,7 @@ import (
)

// Model represents the whole access control model.
type Model struct {
data map[string]AssertionMap
mutex sync.RWMutex
}
type Model map[string]AssertionMap

// AssertionMap is the collection of assertions, can be "r", "p", "g", "e", "m".
type AssertionMap map[string]*Assertion
Expand All @@ -48,19 +44,17 @@ var sectionNameMap = map[string]string{
// Minimal required sections for a model to be valid
var requiredSections = []string{"r", "p", "e", "m"}

func loadAssertion(model *Model, cfg config.ConfigInterface, sec string, key string) bool {
func loadAssertion(model Model, cfg config.ConfigInterface, sec string, key string) bool {
value := cfg.String(sectionNameMap[sec] + "::" + key)
return model.addDef(sec, key, value)
}

// AddDef adds an assertion to the model.
func (model *Model) AddDef(sec string, key string, value string) bool {
model.mutex.Lock()
defer model.mutex.Unlock()
func (model Model) AddDef(sec string, key string, value string) bool {
return model.addDef(sec, key, value)
}

func (model *Model) addDef(sec string, key string, value string) bool {
func (model Model) addDef(sec string, key string, value string) bool {
if value == "" {
return false
}
Expand All @@ -79,12 +73,12 @@ func (model *Model) addDef(sec string, key string, value string) bool {
ast.Value = util.RemoveComments(util.EscapeAssertion(ast.Value))
}

_, ok := model.data[sec]
_, ok := model[sec]
if !ok {
model.data[sec] = make(AssertionMap)
model[sec] = make(AssertionMap)
}

model.data[sec][key] = &ast
model[sec][key] = &ast
return true
}

Expand All @@ -96,7 +90,7 @@ func getKeySuffix(i int) string {
return strconv.Itoa(i)
}

func loadSection(model *Model, cfg config.ConfigInterface, sec string) {
func loadSection(model Model, cfg config.ConfigInterface, sec string) {
i := 1
for {
if !loadAssertion(model, cfg, sec, sec+getKeySuffix(i)) {
Expand All @@ -108,14 +102,13 @@ func loadSection(model *Model, cfg config.ConfigInterface, sec string) {
}

// NewModel creates an empty model.
func NewModel() *Model {
m := new(Model)
m.data = make(map[string]AssertionMap)
func NewModel() Model {
m := make(Model)
return m
}

// NewModelFromFile creates a model from a .CONF file.
func NewModelFromFile(path string) (*Model, error) {
func NewModelFromFile(path string) (Model, error) {
m := NewModel()

err := m.LoadModel(path)
Expand All @@ -127,7 +120,7 @@ func NewModelFromFile(path string) (*Model, error) {
}

// NewModelFromString creates a model from a string which contains model text.
func NewModelFromString(text string) (*Model, error) {
func NewModelFromString(text string) (Model, error) {
m := NewModel()

err := m.LoadModelFromText(text)
Expand All @@ -139,7 +132,7 @@ func NewModelFromString(text string) (*Model, error) {
}

// LoadModel loads the model from model CONF file.
func (model *Model) LoadModel(path string) error {
func (model Model) LoadModel(path string) error {
cfg, err := config.NewConfig(path)
if err != nil {
return err
Expand All @@ -149,7 +142,7 @@ func (model *Model) LoadModel(path string) error {
}

// LoadModelFromText loads the model from the text.
func (model *Model) LoadModelFromText(text string) error {
func (model Model) LoadModelFromText(text string) error {
cfg, err := config.NewConfigFromText(text)
if err != nil {
return err
Expand All @@ -158,9 +151,7 @@ func (model *Model) LoadModelFromText(text string) error {
return model.loadModelFromConfig(cfg)
}

func (model *Model) loadModelFromConfig(cfg config.ConfigInterface) error {
model.mutex.Lock()
defer model.mutex.Unlock()
func (model Model) loadModelFromConfig(cfg config.ConfigInterface) error {
for s := range sectionNameMap {
loadSection(model, cfg, s)
}
Expand All @@ -176,75 +167,61 @@ func (model *Model) loadModelFromConfig(cfg config.ConfigInterface) error {
return nil
}

func (model *Model) hasSection(sec string) bool {
section := model.data[sec]
func (model Model) hasSection(sec string) bool {
section := model[sec]
return section != nil
}

// PrintModel prints the model to the log.
func (model *Model) PrintModel() {
model.mutex.RLock()
defer model.mutex.RUnlock()
func (model Model) PrintModel() {
log.LogPrint("Model:")
for k, v := range model.data {
for k, v := range model {
for i, j := range v {
log.LogPrintf("%s.%s: %s", k, i, j.Value)
}
}
}

// GetMatcher gets the matcher.
func (model *Model) GetMatcher() string {
model.mutex.RLock()
defer model.mutex.RUnlock()
return model.data["m"]["m"].Value
func (model Model) GetMatcher() string {
return model["m"]["m"].Value
}

// GetEffectExpression gets the effect expression.
func (model *Model) GetEffectExpression() string {
model.mutex.RLock()
defer model.mutex.RUnlock()
return model.data["e"]["e"].Value
func (model Model) GetEffectExpression() string {
return model["e"]["e"].Value
}

// GetRoleManager gets the current role manager used in ptype.
func (model *Model) GetRoleManager(sec string, ptype string) rbac.RoleManager {
model.mutex.RLock()
defer model.mutex.RUnlock()
return model.data[sec][ptype].RM
func (model Model) GetRoleManager(sec string, ptype string) rbac.RoleManager {
return model[sec][ptype].RM
}

// GetTokens returns a map with all the tokens
func (model *Model) GetTokens(sec string, ptype string) map[string]int {
model.mutex.RLock()
defer model.mutex.RUnlock()
tokens := make(map[string]int, len(model.data[sec][ptype].Tokens))
for i, token := range model.data[sec][ptype].Tokens {
func (model Model) GetTokens(sec string, ptype string) map[string]int {
tokens := make(map[string]int, len(model[sec][ptype].Tokens))
for i, token := range model[sec][ptype].Tokens {
tokens[token] = i
}

return tokens
}

// GetPtypes returns a slice for all ptype
func (model *Model) GetPtypes(sec string) []string {
model.mutex.RLock()
defer model.mutex.RUnlock()
func (model Model) GetPtypes(sec string) []string {
var res []string
for k := range model.data[sec] {
for k := range model[sec] {
res = append(res, k)
}
return res
}

// GenerateFunctions return a map with all the functions
func (model *Model) GenerateFunctions(fm FunctionMap) map[string]govaluate.ExpressionFunction {
model.mutex.RLock()
defer model.mutex.RUnlock()
func (model Model) GenerateFunctions(fm FunctionMap) map[string]govaluate.ExpressionFunction {
functions := fm.GetFunctions()

if _, ok := model.data["g"]; ok {
for key, ast := range model.data["g"] {
if _, ok := model["g"]; ok {
for key, ast := range model["g"] {
rm := ast.RM
functions[key] = util.GenerateGFunction(rm)
}
Expand Down
Loading

0 comments on commit d9f533d

Please sign in to comment.