diff --git a/src/caosdb/utils/plantuml.py b/src/caosdb/utils/plantuml.py index be34b2604f3682bb71b48bbd73e00fe854b3af51..2cc1bb53e50aca7b26dc2ade423f010f43f6a774 100644 --- a/src/caosdb/utils/plantuml.py +++ b/src/caosdb/utils/plantuml.py @@ -34,10 +34,15 @@ plantuml FILENAME.pu -> FILENAME.png """ import os +import shutil import caosdb as db from caosdb.common.datatype import is_reference, get_referenced_recordtype +from typing import Optional + +import tempfile + REFERENCE = "REFERENCE" @@ -79,13 +84,23 @@ class Grouped(object): return self.parents -def recordtypes_to_plantuml_string(iterable): +def recordtypes_to_plantuml_string(iterable, + add_properties: bool = True, + add_recordtypes: bool = True, + add_legend: bool = True, + style: str = "default"): """Converts RecordTypes into a string for PlantUML. This function obtains an iterable and returns a string which can be input into PlantUML for a representation of all RecordTypes in the iterable. + Current options for style + ------------------------- + + "default" - Standard rectangles with uml class circle and methods section + "salexan" - Round rectangles, hide circle and methods section + Current limitations ------------------- @@ -140,74 +155,89 @@ def recordtypes_to_plantuml_string(iterable): return result result = "@startuml\n\n" - result += "skinparam classAttributeIconSize 0\n" - result += "package Properties #DDDDDD {\n" + if style == "default": + result += "skinparam classAttributeIconSize 0\n" + elif style == "salexan": + result += """skinparam roundcorner 20\n +skinparam boxpadding 20\n +\n +hide methods\n +hide circle\n +""" + else: + raise ValueError("Unknown style.") + + - for p in properties: - inheritances[p] = p.get_parents() - dependencies[p] = [] + if add_properties: + result += "package Properties #DDDDDD {\n" + for p in properties: + inheritances[p] = p.get_parents() + dependencies[p] = [] - result += "class \"{klass}\" << (P,#008800) >> {{\n".format(klass=p.name) + result += "class \"{klass}\" << (P,#008800) >> {{\n".format(klass=p.name) - if p.description is not None: - result += get_description(p.description) - result += "\n..\n" + if p.description is not None: + result += get_description(p.description) + result += "\n..\n" - if isinstance(p.datatype, str): - result += "datatype: " + p.datatype + "\n" - elif isinstance(p.datatype, db.Entity): - result += "datatype: " + p.datatype.name + "\n" - else: - result += "datatype: " + str(p.datatype) + "\n" + if isinstance(p.datatype, str): + result += "datatype: " + p.datatype + "\n" + elif isinstance(p.datatype, db.Entity): + result += "datatype: " + p.datatype.name + "\n" + else: + result += "datatype: " + str(p.datatype) + "\n" + result += "}\n\n" result += "}\n\n" - result += "}\n\n" - result += "package RecordTypes #DDDDDD {\n" + if add_recordtypes: + result += "package RecordTypes #DDDDDD {\n" - for c in classes: - inheritances[c] = c.get_parents() - dependencies[c] = [] - result += "class \"{klass}\" << (C,#FF1111) >> {{\n".format(klass=c.name) + for c in classes: + inheritances[c] = c.get_parents() + dependencies[c] = [] + result += "class \"{klass}\" << (C,#FF1111) >> {{\n".format(klass=c.name) - if c.description is not None: - result += get_description(c.description) + if c.description is not None: + result += get_description(c.description) - props = "" - props += _add_properties(c, importance=db.FIX) - props += _add_properties(c, importance=db.OBLIGATORY) - props += _add_properties(c, importance=db.RECOMMENDED) - props += _add_properties(c, importance=db.SUGGESTED) + props = "" + props += _add_properties(c, importance=db.FIX) + props += _add_properties(c, importance=db.OBLIGATORY) + props += _add_properties(c, importance=db.RECOMMENDED) + props += _add_properties(c, importance=db.SUGGESTED) - if len(props) > 0: - result += "__Properties__\n" + props - else: - result += "\n..\n" - result += "}\n\n" + if len(props) > 0: + result += "__Properties__\n" + props + else: + result += "\n..\n" + result += "}\n\n" - for g in grouped: - inheritances[g] = g.get_parents() - result += "class \"{klass}\" << (G,#0000FF) >> {{\n".format(klass=g.name) - result += "}\n\n" + for g in grouped: + inheritances[g] = g.get_parents() + result += "class \"{klass}\" << (G,#0000FF) >> {{\n".format(klass=g.name) + result += "}\n\n" - for c, parents in inheritances.items(): - for par in parents: - result += "\"{par}\" <|-- \"{klass}\"\n".format( - klass=c.name, par=par.name) + for c, parents in inheritances.items(): + for par in parents: + result += "\"{par}\" <|-- \"{klass}\"\n".format( + klass=c.name, par=par.name) - for c, deps in dependencies.items(): - for dep in deps: - result += "\"{klass}\" *-- \"{dep}\"\n".format( - klass=c.name, dep=dep) + for c, deps in dependencies.items(): + for dep in deps: + result += "\"{klass}\" *-- \"{dep}\"\n".format( + klass=c.name, dep=dep) - result += """ + if add_legend: + result += """ package \"B is a subtype of A\" <<Rectangle>> { A <|-right- B note "This determines what you find when you query for the RecordType.\\n'FIND RECORD A' will provide Records which have a parent\\nA or B, while 'FIND RECORD B' will provide only Records which have a parent B." as N1 } """ - result += """ + result += """ package \"The property P references an instance of D\" <<Rectangle>> { class C { @@ -263,6 +293,15 @@ def retrieve_substructure(start_record_types, depth, result_id_set=None, result_ if is_reference(prop.datatype) and prop.datatype != db.FILE and depth > 0: rt = db.RecordType(name=get_referenced_recordtype(prop.datatype)).retrieve() retrieve_substructure([rt], depth-1, result_id_set, result_container, False) + + # TODO: clean up this hack + # TODO: make it also work for files + if is_reference(prop.datatype) and prop.value is not None: + r = db.Record(id=prop.value).retrieve() + retrieve_substructure([r], depth-1, result_id_set, result_container, False) + if r.id not in result_id_set: + result_container.append(r) + result_id_set.add(r.id) if prop.id not in result_id_set: result_container.append(prop) @@ -281,7 +320,14 @@ def retrieve_substructure(start_record_types, depth, result_id_set=None, result_ return None -def to_graphics(recordtypes, filename): +def to_graphics(recordtypes: list[db.Entity], filename: str, + output_dirname: Optional[str] = None, + formats: list[str] = ["tsvg"], + silent: bool = True, + add_properties: bool = True, + add_recordtypes: bool = True, + add_legend: bool = True, + style: str = "default"): """Calls recordtypes_to_plantuml_string(), saves result to file and creates an svg image @@ -293,17 +339,48 @@ def to_graphics(recordtypes, filename): Iterable with the entities to be displayed. filename : str filename of the image without the extension(e.g. data_structure; + also without the preceeding path. data_structure.pu and data_structure.svg will be created.) + output_dirname : str + the destination directory for the resulting images as defined by the "-o" + option by plantuml + default is to use current working dir + formats : list[str] + list of target formats as defined by the -t"..." options by plantuml, e.g. "tsvg" + silent : bool + Don't output messages. """ - pu = recordtypes_to_plantuml_string(recordtypes) - - pu_filename = filename+".pu" - with open(pu_filename, "w") as pu_file: - pu_file.write(pu) - - cmd = "plantuml -tsvg %s" % pu_filename - print("Executing:", cmd) - - if os.system(cmd) != 0: - raise Exception("An error occured during the execution of plantuml. " - "Is plantuml installed?") + pu = recordtypes_to_plantuml_string(recordtypes, + add_properties, + add_recordtypes, + add_legend, + style) + + if output_dirname is None: + output_dirname = os.getcwd() + + allowed_formats = [ + "tpng", "tsvg", "teps", "tpdf", "tvdx", "txmi", + "tscxml", "thtml", "ttxt", "tutxt", "tlatex", "tlatex:nopreamble"] + + with tempfile.TemporaryDirectory() as td: + + pu_filename = os.path.join(td, filename + ".pu") + with open(pu_filename, "w") as pu_file: + pu_file.write(pu) + + for format in formats: + extension = format[1:] + if ":" in extension: + extension = extension[:extension.index(":")] + + if format not in allowed_formats: + raise RuntimeError("Format not allowed.") + cmd = "plantuml -{} {}".format(format, pu_filename) + if not silent: + print("Executing:", cmd) + + if os.system(cmd) != 0: # TODO: replace with subprocess.run + raise Exception("An error occured during the execution of plantuml. " + "Is plantuml installed?") + shutil.copy(os.path.join(td, filename + "." + extension), output_dirname)