From 4349a359d42fdfee53b85dd5c89a2f169e1dc6b2 Mon Sep 17 00:00:00 2001
From: Chris Roche <github@rodaine.com>
Date: Wed, 30 Jan 2019 11:41:57 -0800
Subject: [PATCH] Properly handle repeated enums (#140)

---
 templates/cc/enum.go            |  6 +++++-
 templates/cc/register.go        | 21 +++++++++++++--------
 templates/cc/repeated.go        |  2 +-
 tests/harness/cases/enums.proto |  6 ++++++
 tests/harness/executor/cases.go | 12 ++++++++++++
 5 files changed, 37 insertions(+), 10 deletions(-)

diff --git a/templates/cc/enum.go b/templates/cc/enum.go
index 20ae2205..009ded9a 100644
--- a/templates/cc/enum.go
+++ b/templates/cc/enum.go
@@ -6,7 +6,11 @@ const enumTpl = `
 		{{ template "in" . }}
 
 		{{ if $r.GetDefinedOnly }}
-			if (!{{ package $f.Type.Enum }}::{{ (typ $f).Element }}_IsValid({{ accessor . }})) {
+			{{ if $f.Type.IsRepeated }}
+				if (!{{ class $f.Type.Element.Enum }}_IsValid({{ accessor . }})) {
+			{{ else }}
+				if (!{{ class $f.Type.Enum }}_IsValid({{ accessor . }})) {
+			{{ end }}
 				{{ err . "value must be one of the defined enum values" }}
 			}
 		{{ end }}
diff --git a/templates/cc/register.go b/templates/cc/register.go
index 8fba934c..383cae9f 100644
--- a/templates/cc/register.go
+++ b/templates/cc/register.go
@@ -10,8 +10,8 @@ import (
 	"github.com/golang/protobuf/ptypes"
 	"github.com/golang/protobuf/ptypes/duration"
 	"github.com/golang/protobuf/ptypes/timestamp"
-	"github.com/lyft/protoc-gen-star"
-	"github.com/lyft/protoc-gen-star/lang/go"
+	pgs "github.com/lyft/protoc-gen-star"
+	pgsgo "github.com/lyft/protoc-gen-star/lang/go"
 	"github.com/lyft/protoc-gen-validate/templates/shared"
 )
 
@@ -143,15 +143,20 @@ func (fns CCFuncs) hasAccessor(ctx shared.RuleContext) string {
 		fns.methodName(ctx.Field.Name()))
 }
 
-func (fns CCFuncs) classBaseName(msg pgs.Message) string {
-	if m, ok := msg.Parent().(pgs.Message); ok {
-		return fmt.Sprintf("%s_%s", fns.classBaseName(m), msg.Name().String())
+type childEntity interface {
+	pgs.Entity
+	Parent() pgs.ParentEntity
+}
+
+func (fns CCFuncs) classBaseName(ent childEntity) string {
+	if m, ok := ent.Parent().(pgs.Message); ok {
+		return fmt.Sprintf("%s_%s", fns.classBaseName(m), ent.Name().String())
 	}
-	return msg.Name().String()
+	return ent.Name().String()
 }
 
-func (fns CCFuncs) className(msg pgs.Message) string {
-	return fns.packageName(msg) + "::" + fns.classBaseName(msg)
+func (fns CCFuncs) className(ent childEntity) string {
+	return fns.packageName(ent) + "::" + fns.classBaseName(ent)
 }
 
 func (fns CCFuncs) packageName(msg pgs.Entity) string {
diff --git a/templates/cc/repeated.go b/templates/cc/repeated.go
index 31f1424a..f73d069e 100644
--- a/templates/cc/repeated.go
+++ b/templates/cc/repeated.go
@@ -45,7 +45,7 @@ const repTpl = `
 
 	{{ if or $r.GetUnique (ne (.Elem "" "").Typ "none") }}
 		for (int i = 0; i < {{ accessor . }}.size(); i++) {
-			const {{ $typ }}& item = {{ accessor . }}.Get(i);
+			const auto& item = {{ accessor . }}.Get(i);
 			(void)item;
 
 			{{ if $r.GetUnique }}
diff --git a/tests/harness/cases/enums.proto b/tests/harness/cases/enums.proto
index f8cf4341..2281fd4b 100644
--- a/tests/harness/cases/enums.proto
+++ b/tests/harness/cases/enums.proto
@@ -39,3 +39,9 @@ message EnumNotIn      { TestEnum val = 1 [(validate.rules).enum = { not_in: [1]
 message EnumAliasNotIn { TestEnumAlias val = 1 [(validate.rules).enum = { not_in: [1]}]; }
 
 message EnumExternal { other_package.Embed.Enumerated val = 1 [(validate.rules).enum.defined_only = true]; }
+
+message RepeatedEnumDefined { repeated TestEnum val = 1 [(validate.rules).repeated.items.enum.defined_only = true]; }
+message RepeatedExternalEnumDefined { repeated other_package.Embed.Enumerated val = 1 [(validate.rules).repeated.items.enum.defined_only = true]; }
+
+message MapEnumDefined { map<string, TestEnum> val = 1 [(validate.rules).map.values.enum.defined_only = true]; }
+message MapExternalEnumDefined { map<string, other_package.Embed.Enumerated> val = 1 [(validate.rules).map.values.enum.defined_only = true]; }
diff --git a/tests/harness/executor/cases.go b/tests/harness/executor/cases.go
index 51262c5d..235324f1 100644
--- a/tests/harness/executor/cases.go
+++ b/tests/harness/executor/cases.go
@@ -917,6 +917,18 @@ var enumCases = []TestCase{
 
 	{"enum external - defined_only - valid", &cases.EnumExternal{Val: other_package.Embed_VALUE}, true},
 	{"enum external - defined_only - invalid", &cases.EnumExternal{Val: math.MaxInt32}, false},
+
+	{"enum repeated - defined_only - valid", &cases.RepeatedEnumDefined{Val: []cases.TestEnum{cases.TestEnum_ONE, cases.TestEnum_TWO}}, true},
+	{"enum repeated - defined_only - invalid", &cases.RepeatedEnumDefined{Val: []cases.TestEnum{cases.TestEnum_ONE, math.MaxInt32}}, false},
+
+	{"enum repeated (external) - defined_only - valid", &cases.RepeatedExternalEnumDefined{Val: []other_package.Embed_Enumerated{other_package.Embed_VALUE}}, true},
+	{"enum repeated (external) - defined_only - invalid", &cases.RepeatedExternalEnumDefined{Val: []other_package.Embed_Enumerated{math.MaxInt32}}, false},
+
+	{"enum map - defined_only - valid", &cases.MapEnumDefined{Val: map[string]cases.TestEnum{"foo": cases.TestEnum_TWO}}, true},
+	{"enum map - defined_only - invalid", &cases.MapEnumDefined{Val: map[string]cases.TestEnum{"foo": math.MaxInt32}}, false},
+
+	{"enum map (external) - defined_only - valid", &cases.MapExternalEnumDefined{Val: map[string]other_package.Embed_Enumerated{"foo": other_package.Embed_VALUE}}, true},
+	{"enum map (external) - defined_only - invalid", &cases.MapExternalEnumDefined{Val: map[string]other_package.Embed_Enumerated{"foo": math.MaxInt32}}, false},
 }
 
 var messageCases = []TestCase{
-- 
GitLab