From 4b472947142e17ca4dd5dddbef61369f94dd8539 Mon Sep 17 00:00:00 2001
From: Alexander Schlemmer <alexander@mail-schlemmer.de>
Date: Tue, 8 Mar 2022 09:22:13 +0100
Subject: [PATCH] ENH: added more customization options to plantuml converter

---
 src/caosdb/utils/plantuml.py | 137 ++++++++++++++++++++++-------------
 1 file changed, 86 insertions(+), 51 deletions(-)

diff --git a/src/caosdb/utils/plantuml.py b/src/caosdb/utils/plantuml.py
index 7f83e0b4..2cc1bb53 100644
--- a/src/caosdb/utils/plantuml.py
+++ b/src/caosdb/utils/plantuml.py
@@ -41,6 +41,8 @@ from caosdb.common.datatype import is_reference, get_referenced_recordtype
 
 from typing import Optional
 
+import tempfile
+
 REFERENCE = "REFERENCE"
 
 
@@ -82,13 +84,23 @@ class Grouped(object):
         return self.parents
 
 
-def recordtypes_to_plantuml_string(iterable):
+def recordtypes_to_plantuml_string(iterable,
+                                   add_properties: bool = True,
+                                   add_recordtypes: bool = True,
+                                   add_legend: bool = True,
+                                   style: str = "default"):
     """Converts RecordTypes into a string for PlantUML.
 
     This function obtains an iterable and returns a string which can
     be input into PlantUML for a representation of all RecordTypes in
     the iterable.
 
+    Current options for style
+    -------------------------
+
+    "default" - Standard rectangles with uml class circle and methods section
+    "salexan" - Round rectangles, hide circle and methods section
+
     Current limitations
     -------------------
 
@@ -143,74 +155,89 @@ def recordtypes_to_plantuml_string(iterable):
         return result
 
     result = "@startuml\n\n"
-    result += "skinparam classAttributeIconSize 0\n"
 
-    result += "package Properties #DDDDDD {\n"
+    if style == "default":
+        result += "skinparam classAttributeIconSize 0\n"
+    elif style == "salexan":
+        result += """skinparam roundcorner 20\n
+skinparam boxpadding 20\n
+\n
+hide methods\n
+hide circle\n
+"""
+    else:
+        raise ValueError("Unknown style.")
 
-    for p in properties:
-        inheritances[p] = p.get_parents()
-        dependencies[p] = []
+    
 
-        result += "class \"{klass}\" << (P,#008800) >> {{\n".format(klass=p.name)
+    if add_properties:
+        result += "package Properties #DDDDDD {\n"
+        for p in properties:
+            inheritances[p] = p.get_parents()
+            dependencies[p] = []
 
-        if p.description is not None:
-            result += get_description(p.description)
-        result += "\n..\n"
+            result += "class \"{klass}\" << (P,#008800) >> {{\n".format(klass=p.name)
 
-        if isinstance(p.datatype, str):
-            result += "datatype: " + p.datatype + "\n"
-        elif isinstance(p.datatype, db.Entity):
-            result += "datatype: " + p.datatype.name + "\n"
-        else:
-            result += "datatype: " + str(p.datatype) + "\n"
+            if p.description is not None:
+                result += get_description(p.description)
+            result += "\n..\n"
+
+            if isinstance(p.datatype, str):
+                result += "datatype: " + p.datatype + "\n"
+            elif isinstance(p.datatype, db.Entity):
+                result += "datatype: " + p.datatype.name + "\n"
+            else:
+                result += "datatype: " + str(p.datatype) + "\n"
+            result += "}\n\n"
         result += "}\n\n"
-    result += "}\n\n"
 
-    result += "package RecordTypes #DDDDDD {\n"
+    if add_recordtypes:
+        result += "package RecordTypes #DDDDDD {\n"
 
-    for c in classes:
-        inheritances[c] = c.get_parents()
-        dependencies[c] = []
-        result += "class \"{klass}\" << (C,#FF1111) >> {{\n".format(klass=c.name)
+        for c in classes:
+            inheritances[c] = c.get_parents()
+            dependencies[c] = []
+            result += "class \"{klass}\" << (C,#FF1111) >> {{\n".format(klass=c.name)
 
-        if c.description is not None:
-            result += get_description(c.description)
+            if c.description is not None:
+                result += get_description(c.description)
 
-        props = ""
-        props += _add_properties(c, importance=db.FIX)
-        props += _add_properties(c, importance=db.OBLIGATORY)
-        props += _add_properties(c, importance=db.RECOMMENDED)
-        props += _add_properties(c, importance=db.SUGGESTED)
+            props = ""
+            props += _add_properties(c, importance=db.FIX)
+            props += _add_properties(c, importance=db.OBLIGATORY)
+            props += _add_properties(c, importance=db.RECOMMENDED)
+            props += _add_properties(c, importance=db.SUGGESTED)
 
-        if len(props) > 0:
-            result += "__Properties__\n" + props
-        else:
-            result += "\n..\n"
-        result += "}\n\n"
+            if len(props) > 0:
+                result += "__Properties__\n" + props
+            else:
+                result += "\n..\n"
+            result += "}\n\n"
 
-    for g in grouped:
-        inheritances[g] = g.get_parents()
-        result += "class \"{klass}\" << (G,#0000FF) >> {{\n".format(klass=g.name)
-    result += "}\n\n"
+        for g in grouped:
+            inheritances[g] = g.get_parents()
+            result += "class \"{klass}\" << (G,#0000FF) >> {{\n".format(klass=g.name)
+        result += "}\n\n"
 
-    for c, parents in inheritances.items():
-        for par in parents:
-            result += "\"{par}\" <|-- \"{klass}\"\n".format(
-                klass=c.name, par=par.name)
+        for c, parents in inheritances.items():
+            for par in parents:
+                result += "\"{par}\" <|-- \"{klass}\"\n".format(
+                    klass=c.name, par=par.name)
 
-    for c, deps in dependencies.items():
-        for dep in deps:
-            result += "\"{klass}\" *-- \"{dep}\"\n".format(
-                klass=c.name, dep=dep)
+        for c, deps in dependencies.items():
+            for dep in deps:
+                result += "\"{klass}\" *-- \"{dep}\"\n".format(
+                    klass=c.name, dep=dep)
 
-    result += """
+    if add_legend:
+        result += """
 
 package \"B is a subtype of A\" <<Rectangle>> {
  A <|-right- B
  note  "This determines what you find when you query for the RecordType.\\n'FIND RECORD A' will provide Records which have a parent\\nA or B, while 'FIND RECORD B' will provide only Records which have a parent B." as N1
 }
 """
-    result += """
+        result += """
 
 package \"The property P references an instance of D\" <<Rectangle>> {
  class C {
@@ -295,8 +322,12 @@ def retrieve_substructure(start_record_types, depth, result_id_set=None, result_
 
 def to_graphics(recordtypes: list[db.Entity], filename: str,
                 output_dirname: Optional[str] = None,
-                formats: list[str]=["tsvg"],
-                silent:bool=True):
+                formats: list[str] = ["tsvg"],
+                silent: bool = True,
+                add_properties: bool = True,
+                add_recordtypes: bool = True,
+                add_legend: bool = True,
+                style: str = "default"):
     """Calls recordtypes_to_plantuml_string(), saves result to file and
     creates an svg image
 
@@ -319,7 +350,11 @@ def to_graphics(recordtypes: list[db.Entity], filename: str,
     silent : bool
              Don't output messages.
     """
-    pu = recordtypes_to_plantuml_string(recordtypes)
+    pu = recordtypes_to_plantuml_string(recordtypes,
+                                        add_properties,
+                                        add_recordtypes,
+                                        add_legend,
+                                        style)
 
     if output_dirname is None:
         output_dirname = os.getcwd()
-- 
GitLab