API
 
Loading...
Searching...
No Matches
generateTemplatedCatch2Tests.py
Go to the documentation of this file.
1#!/bin/env python3
2
3'''
4Generate Catch2 tests from template.
5See README.md for more details.
6'''
7
8import os
9import sys
10import subprocess
11import glob
12import re
13import pathlib
14import string
15import random
16import getopt
17
18
# Per-type counters used when generating incrementing (rather than random)
# test values; read and advanced by getTestValFromType / getIncrementingInt.
gNextVals = dict.fromkeys(
    (
        "string",
        "int64",
        "uint64",
        "int32",
        "uint32",
        "int16",
        "uint16",
        "int8",
        "uint8",
        "float",
        "double",
    ),
    0,
)

# Switched on by the -i command-line option in main().
gIncrementingVals = False
33
# Jinja2 renders the Catch2 test template (see main). If it is missing,
# install it into the interpreter running this script, then retry the import.
try:
    import jinja2
except ModuleNotFoundError:
    print("module 'Jinja2' is not installed. Installing Jinja2...")
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "Jinja2"]
    )
    import jinja2
41
42
def getBaseType(lines : list) -> str:
    '''
    Get the base type of a log.

    This is needed for log types that inherit from a base type that specifies
    the messageT(...) constructor(s).

    :param lines: lines of the log type's .hpp header.
    :returns: the base struct name from the last line of the form
              ``struct <name> : public <base>...``, with any template
              arguments removed, or "" if no such line exists.
    '''
    # Capture only the base's identifier, so trailing text on the same line
    # (template arguments containing spaces such as `<a, b>`, or an opening
    # `{`) cannot be mistaken for the base name.
    pattern = re.compile(r'^struct [a-z_]* : public ([A-Za-z_0-9:]*)')
    baseType = ""
    for line in lines:
        match = pattern.search(line)
        if match is not None:
            baseType = match.group(1)  # last matching line wins
    return baseType
57
58
def getSchemaFieldInfo(fname : str) -> tuple[str, tuple] :
    '''
    Parse ../types/schemas/<fname>.fbs (relative to this script's directory).

    NOTE: This relies on name order in .fbs schema and .hpp files to be the same.

    :param fname: log type name (file stem of the .fbs schema).
    :returns: (schemaTableName, fields). schemaTableName is the name of the
              `table <log_type>_fb` table, or "" if none was found. fields is
              a tuple in schema order. Each entry is a (name, type) tuple, or,
              for a field whose type is another table defined in the same
              schema, a dict {fieldName: [(subName, subType), ...]}.
              Returns ("", ()) if the schema file does not exist.
    '''
    schemaFolderPath = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "./../types/schemas/")
    )
    schemaFilePath = os.path.join(schemaFolderPath, f"{fname}.fbs")
    if not os.path.isfile(schemaFilePath):
        return "", tuple()

    with open(schemaFilePath, "r") as schemaFile:
        schemaLines = schemaFile.readlines()

    schemaFieldInfo = []
    subTables = dict()  # key is sub-table name, value is [(fieldname, type)...]
    curSubTable = None
    inTable = False
    schemaTableName = ""
    for rawLine in schemaLines:
        # Ignore anything after `//`. This way a field with a trailing comment
        # is still parsed (keeping positional order intact), and a commented
        # closing brace still ends the table.
        line = rawLine.split("//")[0]

        # A table declaration must start with the `table` keyword. A plain
        # substring test would also fire on field names such as `stable`.
        tableDecl = re.match(r'\s*table\s+([A-Za-z_0-9]+)', line)
        if tableDecl is not None:
            tableName = tableDecl.group(1)
            if re.match(r'[A-Za-z_0-9]*_fb', tableName):
                # the log's own `table <log_type>_fb`
                schemaTableName = tableName
            else:
                # otherwise it is a sub-table
                subTables[tableName] = []
                curSubTable = tableName  # we are in a sub-table of the schema

        if not inTable and "{" in line:
            inTable = True
            continue

        if not inTable:
            continue

        line = line.strip()
        if "}" in line:
            inTable = False
            curSubTable = None
            continue

        # field line: `name:type;` (possibly followed by attributes)
        if line == "" or ":" not in line:
            continue
        nameText, typeText = line.rstrip(";").split(":", 1)
        typeTokens = typeText.split()
        if not typeTokens:
            continue
        field = (nameText.strip(), typeTokens[0])

        if curSubTable is not None:
            # add to subtable dict for now, will be merged in below
            subTables[curSubTable].append(field)
        else:
            schemaFieldInfo.append(field)

    # replace each field whose type is a sub-table with {fieldName: subFields}
    mergedFieldInfo = tuple(
        {name: subTables[fieldType]} if fieldType in subTables else (name, fieldType)
        for name, fieldType in schemaFieldInfo
    )
    return schemaTableName, mergedFieldInfo
135
136
def typesCorrespond(fbsType : str, cType : str) -> bool:
    '''
    Quick check that a .fbs field type and its C++ messageT type agree.

    Vectors must pair with vectors and strings with strings (`std::string` or
    `char *`). Every other pairing is accepted. If the types do not
    correspond, the behavior for comparing the fb values in the tests is
    undefined, and action beyond this generator will need to be taken.
    '''
    fbsIsVector = "[" in fbsType
    cIsVector = "vector" in cType
    if fbsIsVector or cIsVector:
        return fbsIsVector and cIsVector

    fbsIsString = "string" in fbsType
    cIsString = ("string" in cType) or ("char *" in cType)
    if fbsIsString or cIsString:
        return fbsIsString and cIsString

    return True
151
152
def isValidLogType(lines : list) -> bool:
    '''
    Check that a header describes a concrete log type, not a base log type.

    A concrete log type declares both an eventCode and a defaultLevel
    (each may appear on any line).
    '''
    eventCodePattern = "flatlogs::eventCodeT eventCode = eventCodes::[A-Za-z_0-9]*;"
    defaultLevelPattern = "flatlogs::logPrioT defaultLevel = flatlogs::logPrio::[A-Za-z_0-9]*;"

    foundEventCode = any(re.search(eventCodePattern, line) for line in lines)
    foundDefaultLevel = any(re.search(defaultLevelPattern, line) for line in lines)
    return foundEventCode and foundDefaultLevel
177
def makeTestInfoDict(hppFname : str, baseTypesDict : dict) -> dict:
    '''
    Build the template-rendering dictionary for the log type in a .hpp file.

    :param hppFname: path to the log type's .hpp header.
    :param baseTypesDict: maps a base type name to the set of log type names
                          that inherit from it. Updated in place. Headers that
                          are not concrete log types (no eventCode/defaultLevel)
                          are registered with an empty set. Log types with no
                          messageT(...) of their own are added under their
                          base type.
    :returns: dict with the keys used by the Catch2 template, or None when no
              tests can be rendered from this file alone (base type, or an
              inherited type that is rendered later from its base).
    '''
    with open(hppFname, "r") as headerFile:
        headerLines = headerFile.readlines()

    returnInfo = dict()

    # add name of test/file/type to be generated
    logName = os.path.basename(hppFname).strip().split(".")[0]
    nameWords = logName.split("_")
    camelCase = "".join(word.capitalize() for word in nameWords)
    returnInfo["name"] = logName
    returnInfo["nameCamelCase"] = camelCase[0].lower() + camelCase[1:]
    returnInfo["genTestFname"] = f"{logName}_generated_tests.cpp"
    returnInfo["className"] = "C" + camelCase
    returnInfo["classVarName"] = "".join(word[0].lower() for word in nameWords)
    returnInfo["baseType"] = getBaseType(headerLines)
    returnInfo["hasGeneratedHfile"] = hasGeneratedHFile(logName)

    # cannot generate tests from this file alone, need base type
    if not isValidLogType(headerLines):
        baseTypesDict.setdefault(logName, set())
        return None  # don't render anything from this file

    # Scan the header to:
    # 1. find where messageT structs are being made -> describes fields
    # 2. check that it has its own <Get|Create|Verify><name>_fb methods
    # (Plain concatenation here: nesting the same quote type inside an
    # f-string is a SyntaxError before Python 3.12.)
    fbMethodName = "Create" + logName[0].upper() + logName[1:] + "_fb"
    messageStructIdxs = [i for i, line in enumerate(headerLines) if "messageT(" in line]
    hasFbMethods = any(fbMethodName in line for line in headerLines)

    schemaTableName, schemaFieldInfo = getSchemaFieldInfo(logName)
    returnInfo["schemaTableName"] = schemaTableName

    # handle log types that inherit from base types
    if len(messageStructIdxs) == 0:
        # add inherited type to the set of types inheriting from its base
        baseTypesDict.setdefault(returnInfo["baseType"], set()).add(logName)
        return None  # don't render me yet!

    # if it does not have its own fb method, find name of the one it's using
    if not hasFbMethods:
        createPattern = re.compile("^.*Create[a-zA-Z_]*_fb.*$")
        for line in headerLines:
            if returnInfo["schemaTableName"] != "":
                break  # only the first match is used
            if createPattern.search(line):
                # figure out name of fb methods this type is re-using, e.g. ao_observer -> observer
                startIndex = line.find("Create") + len("Create")
                endIndex = line.find("_fb")
                returnInfo["schemaTableName"] = f"{line[startIndex:endIndex]}_fb"

    returnInfo["messageTypes"] = getMessageFieldInfo(messageStructIdxs, headerLines, schemaFieldInfo)

    return returnInfo
240
def getTypeAndName(lineParts : list) -> tuple[str, str]:
    '''
    Parse out field type and name from one messageT(...) parameter.

    :param lineParts: the parameter's whitespace-split tokens, e.g.
                      ['const', 'std::string', '&', 'msg)'].
    :returns: (fieldType, name). A leading `const` and any reference marker
              are dropped. A pointer marker is folded into the type as " *",
              however it was spaced (`char*`, `char *`, `char *x`). Trailing
              ")" / "," are stripped from the name.
    '''
    typeIdxStart = 1 if (lineParts[0] == "const") else 0
    fieldType = lineParts[typeIdxStart]
    nameIdx = typeIdxStart + 1

    # Normalize `T&` / `T*` written without a space, so they come out the
    # same as the `T &` / `T *` spellings handled below.
    if fieldType.endswith("&"):
        fieldType = fieldType.rstrip("&")
    elif fieldType.endswith("*"):
        fieldType = fieldType.rstrip("*") + " *"

    # a standalone `&` or `*` token sits between the type and the name
    if lineParts[nameIdx] == "&":
        nameIdx += 1
    elif lineParts[nameIdx] == "*":
        nameIdx += 1
        fieldType += " *"

    name = lineParts[nameIdx].rstrip(")").rstrip(",")

    # `*` attached to the name (`char *msg`) still marks a pointer type
    if name.startswith("*"):
        fieldType += " *"

    name = name.lstrip("&*")

    return fieldType, name
265
def hasGeneratedHFile(logName : str) -> bool:
    '''
    Check whether the log type has a generated header,
    ../types/generated/<logName>_generated.h (relative to this script's
    directory).
    '''
    scriptDir = os.path.dirname(__file__)
    generatedDir = os.path.abspath(os.path.join(scriptDir, "./../types/generated/"))
    return os.path.isfile(os.path.join(generatedDir, f"{logName}_generated.h"))
280
def getIntSize(type : str) -> int:
    '''
    Return the bit width of a C integer type name.

    Fixed-width names (`int8_t`, `uint16_t`, `int32_t`, `uint64_t`, ...) give
    the width written in the name. Anything else (e.g. plain `int`) defaults
    to 32 bits.
    '''
    # Only digits immediately before `_t` encode a width. Slicing characters
    # off the name instead raised ValueError for other `_t` type names.
    match = re.search(r'(\d+)_t', type)
    return int(match.group(1)) if match is not None else 32
289
290
def getRandInt(type : str) -> int:
    '''
    Return a uniformly random integer that fits the given C integer type.

    The width comes from getIntSize(). Types containing "uint" use the full
    unsigned range [0, 2**bits - 1]. All others use the signed range
    [-2**(bits-1), 2**(bits-1) - 1].
    '''
    bits = getIntSize(type)
    if "uint" in type:
        lo, hi = 0, 2 ** bits - 1
    else:
        hi = 2 ** (bits - 1) - 1
        lo = -hi - 1
    return random.randint(lo, hi)
303
def getIncrementingInt(type : str) -> int:
    '''
    Advance and return the incrementing test-value counter for an integer type.

    Each fixed-width type has its own counter in gNextVals. Any unrecognised
    type (e.g. plain `int`) shares the int32 counter and limits. A counter
    wraps modulo the type's largest positive value, so every returned value
    fits the C type.
    '''
    # (C type, gNextVals key, bits, signed). Unsigned names are tested first
    # because "int8_t" is a substring of "uint8_t" (likewise 16/32/64).
    counters = (
        ("uint8_t", "uint8", 8, False),
        ("uint16_t", "uint16", 16, False),
        ("uint32_t", "uint32", 32, False),
        ("uint64_t", "uint64", 64, False),
        ("int8_t", "int8", 8, True),
        ("int16_t", "int16", 16, True),
        ("int32_t", "int32", 32, True),
        ("int64_t", "int64", 64, True),
    )
    counterKey, bits, isSigned = "int32", 32, True  # default, e.g. plain `int`
    for cType, key, nBits, signed in counters:
        if cType in type:
            counterKey, bits, isSigned = key, nBits, signed
            break

    # Signed types must stay below 2**(bits-1), or e.g. int8_t would overflow.
    maxVal = (2 ** (bits - 1) - 1) if isSigned else (2 ** bits - 1)
    gNextVals[counterKey] = (gNextVals[counterKey] + 1) % maxVal
    return gNextVals[counterKey]
336
def getTestValFromType(fieldType : str, schemaFieldType = None) -> str:
    '''
    Return a C++ literal (as a string) to use as the test value for a field.

    Values are random unless gIncrementingVals is set. In that case they come
    from the per-type counters in gNextVals.

    :param fieldType: C++ type of the messageT parameter.
    :param schemaFieldType: matching .fbs type, if known. A schema `bool`
                            forces a boolean value.
    :returns: "1" for bools, a double-quoted string for strings, an
              integer/float literal for numeric types, or "{}" for anything
              else.
    '''
    schemaSaysBool = schemaFieldType is not None and "bool" in schemaFieldType
    if "bool" in fieldType or schemaSaysBool:
        return "1"

    if "string" in fieldType or "char *" in fieldType:
        if not gIncrementingVals:
            randString = ''.join(random.choices(string.ascii_lowercase + string.digits, k=10))
            return f'"{randString}"'
        gNextVals["string"] += 1
        return f'"{gNextVals["string"]}"'

    if "int" in fieldType:
        if gIncrementingVals:
            return str(getIncrementingInt(fieldType))
        randVal = str(getRandInt(fieldType))
        # need 'u' suffix for randomly generated uint64_t to avoid:
        # "warning: integer constant is so large that it is unsigned"
        return randVal + "u" if "uint64_t" in fieldType else randVal

    if "float" in fieldType:
        if gIncrementingVals:
            gNextVals["float"] += 1
            return str(round(gNextVals["float"] / 100000, 6))
        return str(round(random.random(), 6))

    if "double" in fieldType:
        if gIncrementingVals:
            gNextVals["double"] += 1
            return str(round(gNextVals["double"] / 10000000000, 14))
        return str(round(random.random(), 14))

    return "{}"
364
365
def makeTestVal(fieldDict : dict) -> str:
    '''
    Build the C++ test-value literal for one messageT field.

    Vector fields get a brace initializer of 10 element values generated from
    fieldDict["vectorType"]. Other fields use getTestValFromType, which is
    also given the .fbs type when fieldDict has a "schemaType".

    Special case (telem_pokecenter): a `pokes` field of type vector<float...>
    follows a specific format. For it, this also stores
    fieldDict["specialAssertVal"], a brace list of every other element
    (indices 0, 2, 4, ...). Note this mutates fieldDict.
    '''
    if "vector" in fieldDict["type"]:
        vals = [getTestValFromType(fieldDict["vectorType"]) for _ in range(10)]

        # special case telem_pokecenter because vector follows specific format.
        # (Plain concatenation: nesting the same quote type inside an f-string
        # is a SyntaxError before Python 3.12.)
        if fieldDict["name"] == "pokes" and "vector<float" in fieldDict["type"]:
            fieldDict["specialAssertVal"] = "{ " + ",".join(vals[::2]) + " }"
        return "{ " + ",".join(vals) + " }"

    if "schemaType" in fieldDict:
        return getTestValFromType(fieldDict["type"], fieldDict["schemaType"])

    return getTestValFromType(fieldDict["type"])
380
381
382
def getMessageFieldInfo(messageStructIdxs: list, lines : list, schemaFieldInfo : tuple):
    '''
    Make a 2d list describing the parameters of each messageT(...) constructor.

    Each inner list holds one dict per constructor parameter:
    [ [ {"type": ..., "name": ..., "testVal": ...}, ... ], ... ]
    A dict may also carry:
      - "vectorType" for std::vector fields;
      - "schemaName"/"schemaType", paired positionally from schemaFieldInfo
        while schema entries remain. "schemaName" is removed again if
        typesCorrespond() rejects the pairing.

    :param messageStructIdxs: indices into `lines` where "messageT(" appears.
    :param lines: header file lines.
    :param schemaFieldInfo: schema fields in schema order. Entries are
                            (name, type) tuples or
                            {subTableName: [(name, type), ...]} dicts.
    '''
    msgTypesList = []
    # position within the current sub-table's field list; only reset once
    # that sub-table is exhausted
    subTableDictIndex = 0

    # extract log field types and names
    for structIdx in messageStructIdxs:
        msgsFieldsList = []

        closed = False
        fieldCount = 0  # index of the next schemaFieldInfo entry to pair with
        while not closed and structIdx < len(lines):

            line = lines[structIdx]

            # check if this is a closing line
            if ")" in line:
                if "//" in line and line.find(")") > line.find("//"):
                    # parenthesis is in comment
                    pass
                elif line.strip().strip(")") == "":
                    break  # line is only closing paren(s): no fields on it
                else:
                    closed = True  # parse the field, don't leave loop yet

            # trim line to just get field info
            indexStart = (line.find("messageT(") + len("messageT(")) if "messageT(" in line else 0

            indexEnd = len(line)
            if "//" in line:
                indexEnd = line.find("//")
            elif "/*" in line and line.find("/*") < indexEnd:
                indexEnd = line.find("/*")

            # ignore default argument
            if "=" in line and line.find("=") < indexEnd:
                indexEnd = line.find("=")

            fieldText = line[indexStart:indexEnd].strip()

            lineParts = [part.strip().split() for part in fieldText.strip().rstrip(",").split(",")]

            for field in lineParts:
                # Fewer than two tokens cannot be `<type> <name>` (e.g. a
                # stray ")" from `messageT( )`), and would make getTypeAndName
                # index past the end; stop parsing this line.
                if len(field) < 2 or "//" in field[0]:
                    break

                # find type and name
                fieldType, name = getTypeAndName(field)
                fieldDict = {"type": fieldType, "name": name}

                # get vector type if necessary
                if "std::vector" in fieldType:
                    typeParts = fieldType.split("<")
                    vectorIdx = next(idx for idx, part in enumerate(typeParts) if "std::vector" in part)
                    fieldDict["vectorType"] = typeParts[vectorIdx + 1].strip(">")

                # pair with the next schema entry, if any remain (indexing past
                # the end used to raise IndexError when the header had more
                # parameters than the schema)
                if fieldCount < len(schemaFieldInfo):
                    schemaEntry = schemaFieldInfo[fieldCount]
                    if isinstance(schemaEntry, tuple):
                        fieldDict["schemaName"] = schemaEntry[0]
                        fieldDict["schemaType"] = schemaEntry[1]
                        fieldCount += 1
                    else:
                        # sub-table entry: consume one of its fields per parameter
                        subTableName = next(iter(schemaEntry))
                        subFields = schemaEntry[subTableName]
                        schemaFieldName, schemaFieldType = subFields[subTableDictIndex]
                        fieldDict["schemaName"] = f"{subTableName}()->{schemaFieldName}"
                        fieldDict["schemaType"] = schemaFieldType
                        subTableDictIndex += 1
                        if subTableDictIndex >= len(subFields):
                            # reset dictionary index if we need to
                            subTableDictIndex = 0
                            fieldCount += 1

                    # check schemaType correlates to type in .hpp file
                    if not typesCorrespond(fieldDict["schemaType"], fieldDict["type"]):
                        # if types don't correspond, then use name in messageT and hope for best.
                        # this is why if types are different, then names MUST correspond between
                        # .fbs and .hpp file
                        del fieldDict["schemaName"]

                fieldDict["testVal"] = makeTestVal(fieldDict)

                # add field dict to list of fields
                msgsFieldsList.append(fieldDict)

            structIdx += 1

        msgTypesList.append(msgsFieldsList)

    return msgTypesList
485
def makeInheritedTypeInfoDict(typesFolderPath : str, baseName : str, logName : str) -> dict:
    '''
    Build the template-rendering dictionary for a log type whose messageT(...)
    constructors come from a base type.

    Fields are parsed from <typesFolderPath>/<baseName>.hpp together with the
    <baseName>.fbs schema. Names and file names are derived from logName.
    If baseName contains "empty_log", messageTypes is [[]] (one message with
    no fields).
    '''
    baseFilePath = os.path.join(typesFolderPath, f"{baseName}.hpp")
    with open(baseFilePath, "r") as baseHFile:
        baseHLines = baseHFile.readlines()

    returnInfo = dict()

    # add name of test/file/type to be generated
    nameWords = logName.split("_")
    camelCase = "".join(word.capitalize() for word in nameWords)
    returnInfo["name"] = logName
    returnInfo["genTestFname"] = f"{logName}_generated_tests.cpp"
    returnInfo["className"] = "C" + camelCase
    returnInfo["nameCamelCase"] = camelCase[0].lower() + camelCase[1:]
    returnInfo["classVarName"] = "".join(word[0].lower() for word in nameWords)
    returnInfo["baseType"] = baseName
    returnInfo["hasGeneratedHfile"] = hasGeneratedHFile(logName)

    # find where messageT structs are being made in base log file -> describes fields
    messageStructIdxs = [i for i, line in enumerate(baseHLines) if "messageT(" in line]

    schemaTableName, schemaFieldInfo = getSchemaFieldInfo(baseName)
    returnInfo["schemaTableName"] = schemaTableName

    # Always parsed, even for empty_log, so that test-value generation
    # (random/incrementing) advances exactly as before.
    msgFieldInfo = getMessageFieldInfo(messageStructIdxs, baseHLines, schemaFieldInfo)

    returnInfo["messageTypes"] = [[]] if "empty_log" in baseName else msgFieldInfo

    return returnInfo
520
def versionAsNumber(major, minor):
    '''Pack a (major, minor) version pair into one integer for ordering comparisons.'''
    return minor + 1000 * major
523
def _renderToFile(template, info : dict, outFolder : str) -> None:
    '''Render `template` with `info` and write it to outFolder/info["genTestFname"].'''
    renderedHeader = template.render(info)
    outPath = os.path.join(outFolder, info["genTestFname"])
    with open(outPath, "w") as outfile:
        print(renderedHeader, file=outfile)


def main():
    '''
    Generate a Catch2 test file for every log type header in ../types
    (relative to this script's directory) into ./generated_tests/.

    Options (at most one):
      -s <seed>  seed the random generator for reproducible random values
      -i         use incrementing instead of random test values

    Files already in the output folder are deleted first. Command-line errors
    exit with status 2; an unsupported Python version exits with status 1.
    '''
    # check python version >= 3.9
    if versionAsNumber(sys.version_info[0], sys.version_info[1]) < versionAsNumber(3, 9):
        print("Error: Python version must be >= 3.9")
        sys.exit(1)

    global gIncrementingVals
    gIncrementingVals = False

    # getopt for random seed or incrementing vals
    try:
        opts, _ = getopt.getopt(sys.argv[1:], "is:")
    except getopt.GetoptError:
        print("Usage: python3 ./generateTemplatedCatch2Tests.py -s <seed> | -i")
        sys.exit(2)
    if len(opts) > 1:
        print("Error: Only one option allowed. -s <seed> or -i for incrementing values.")
        sys.exit(2)

    for opt, arg in opts:
        if opt == "-s":
            if not arg.isdigit():
                print(f"Error: random seed {arg} provided is not an integer.")
                sys.exit(2)
            # use random seed if provided with -s
            random.seed(int(arg))
        elif opt == "-i":
            gIncrementingVals = True

    scriptDir = os.path.dirname(__file__)

    # load template
    env = jinja2.Environment(
        loader=jinja2.FileSystemLoader(searchpath=scriptDir),
        trim_blocks=True,
        lstrip_blocks=True,
    )
    catchTemplate = env.get_template("catch2TestTemplate.jinja2")

    # path to .hpp files here
    typesFolderPath = os.path.abspath(os.path.join(scriptDir, "./../types"))

    # generated tests output path; make it if missing and clear old files
    generatedTestsFolderPath = os.path.abspath(os.path.join(scriptDir, "./generated_tests/"))
    pathlib.Path(generatedTestsFolderPath).mkdir(exist_ok=True)
    for oldFile in glob.glob(os.path.join(generatedTestsFolderPath, "*")):
        if os.path.isfile(oldFile):
            os.remove(oldFile)

    baseTypesDict = dict()  # map baseTypes to the types that inherit from them
    for typeFname in sorted(os.listdir(typesFolderPath)):

        print(typeFname)
        # check valid type to generate tests for (skips e.g. `x.hpp.orig`)
        if not typeFname.endswith(".hpp"):
            continue

        # workaround for software_log issues with source_location
        if "software" in typeFname:
            print("software")
            continue

        # workaround for telsee deprecated fields
        if "telsee" in typeFname:
            print("telsee")
            continue

        # make dictionary with info for template
        info = makeTestInfoDict(os.path.join(typesFolderPath, typeFname), baseTypesDict)
        if info is None:
            # base type, or inherited type rendered below
            continue

        _renderToFile(catchTemplate, info, generatedTestsFolderPath)

    # Handle types that inherit from baseTypes. The inherited names are in a
    # set, whose iteration order follows str hashing and changes between runs.
    # Sort them so test values (drawn in order) are reproducible with -s / -i.
    for baseType, inheritedTypes in baseTypesDict.items():
        for inheritedType in sorted(inheritedTypes):
            info = makeInheritedTypeInfoDict(typesFolderPath, baseType, inheritedType)
            if info is None:
                # empty dictionary, no tests to make
                continue

            _renderToFile(catchTemplate, info, generatedTestsFolderPath)
636
637
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
dict makeInheritedTypeInfoDict(str typesFolderPath, str baseName, str logName)
dict makeTestInfoDict(str hppFname, dict baseTypesDict)
tuple[str, str] getTypeAndName(list lineParts)
str getTestValFromType(str fieldType, schemaFieldType=None)
tuple[str, tuple] getSchemaFieldInfo(str fname)
bool typesCorrespond(str fbsType, str cType)
getMessageFieldInfo(list messageStructIdxs, list lines, tuple schemaFieldInfo)