Skip to content
Open
17 changes: 17 additions & 0 deletions commons/src/main/java/com/powsybl/commons/xml/XmlUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import javanet.staxutils.IndentingXMLStreamWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xml.sax.SAXException;

import javax.xml.XMLConstants;
import javax.xml.parsers.DocumentBuilderFactory;
Expand All @@ -24,6 +25,7 @@
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.XMLStreamReader;
import javax.xml.stream.XMLStreamWriter;
import javax.xml.validation.SchemaFactory;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
Expand Down Expand Up @@ -180,6 +182,21 @@ public static XMLStreamWriter initializeWriter(boolean indent, String indentStri
return initializeWriter(indent, indentString, xmlWriter);
}

public static SchemaFactory newSchemaFactory() {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename to createSchemaFactoryInstance to mirror createXMLInputFactoryInstance

SchemaFactory factory = SchemaFactory.newInstance(XMLConstants.W3C_XML_SCHEMA_NS_URI);
try {
factory.setProperty(XMLConstants.ACCESS_EXTERNAL_SCHEMA, "");
} catch (SAXException e) {
LOGGER.info("- Property unsupported by SchemaFactory implementation: {}", XMLConstants.ACCESS_EXTERNAL_SCHEMA);
}
Comment on lines +187 to +191
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can factorise this code in a small utility method used for the two properties

try {
factory.setProperty(XMLConstants.ACCESS_EXTERNAL_DTD, "");
} catch (SAXException e) {
LOGGER.info("- Property unsupported by SchemaFactory implementation: {}", XMLConstants.ACCESS_EXTERNAL_DTD);
}
return factory;
}

/**
 * Convenience overload that delegates to the four-argument {@code initializeWriter},
 * passing {@link StandardCharsets#UTF_8} as the charset.
 *
 * @param indent           forwarded unchanged to the charset-aware overload
 * @param indentString     forwarded unchanged to the charset-aware overload
 * @param initialXmlWriter the writer forwarded unchanged to the charset-aware overload
 * @return whatever the charset-aware overload returns
 * @throws XMLStreamException propagated from the charset-aware overload
 */
private static XMLStreamWriter initializeWriter(boolean indent, String indentString, XMLStreamWriter initialXmlWriter) throws XMLStreamException {
return initializeWriter(indent, indentString, initialXmlWriter, StandardCharsets.UTF_8);
}
Expand Down
22 changes: 20 additions & 2 deletions commons/src/test/java/com/powsybl/commons/xml/XmlUtilTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@
import com.google.common.collect.ImmutableMap;
import com.powsybl.commons.PowsyblException;
import org.junit.jupiter.api.Test;
import org.xml.sax.SAXNotRecognizedException;
import org.xml.sax.SAXNotSupportedException;

import javax.xml.XMLConstants;
import javax.xml.stream.*;
import javax.xml.validation.SchemaFactory;
import java.io.ByteArrayOutputStream;
import java.io.StringReader;
import java.nio.charset.StandardCharsets;
Expand All @@ -22,8 +26,7 @@
import java.util.concurrent.atomic.AtomicReference;

import static com.powsybl.commons.xml.XmlUtil.getXMLInputFactory;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.*;

/**
* @author Geoffroy Jamgotchian {@literal <geoffroy.jamgotchian at rte-france.com>}
Expand Down Expand Up @@ -188,4 +191,19 @@ void initializeWriter() throws XMLStreamException {
writer.close();
assertEquals("<?xml version=\"1.0\" encoding=\"ISO-8859-1\"?>", baos.toString());
}

@Test
void testSchemaFactory() {
    // The factory returned by XmlUtil must exist and, where the implementation exposes them,
    // report both external-access properties as the empty string.
    SchemaFactory schemaFactory = XmlUtil.newSchemaFactory();
    assertNotNull(schemaFactory);

    String[] restrictedProperties = {
        XMLConstants.ACCESS_EXTERNAL_SCHEMA,
        XMLConstants.ACCESS_EXTERNAL_DTD
    };
    try {
        for (String property : restrictedProperties) {
            Object configuredValue = schemaFactory.getProperty(property);
            assertEquals("", configuredValue);
        }
    } catch (SAXNotSupportedException | SAXNotRecognizedException ignored) {
        // ignored: the property cannot be read back from this SchemaFactory implementation,
        // so the remaining checks are skipped
    }
}
}
219 changes: 214 additions & 5 deletions iidm/iidm-serde/src/main/java/com/powsybl/iidm/serde/NetworkSerDe.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
import org.xml.sax.SAXException;

import javax.xml.XMLConstants;
import javax.xml.stream.XMLStreamConstants;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.XMLStreamReader;
import javax.xml.transform.Source;
import javax.xml.transform.stream.StreamSource;
import javax.xml.validation.Schema;
Expand All @@ -53,13 +55,18 @@
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ForkJoinPool;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

import static com.powsybl.commons.xml.XmlUtil.getXMLInputFactory;
import static com.powsybl.commons.xml.XmlUtil.newSchemaFactory;
import static com.powsybl.iidm.serde.AbstractTreeDataImporter.SUFFIX_MAPPING;
import static com.powsybl.iidm.serde.IidmSerDeConstants.IIDM_PREFIX;
import static com.powsybl.iidm.serde.IidmSerDeConstants.INDENT;
Expand All @@ -85,7 +92,15 @@ public final class NetworkSerDe {
static final byte[] BIIDM_MAGIC_NUMBER = {0x42, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x20, 0x49, 0x49, 0x44, 0x4d};

private static final Supplier<Schema> DEFAULT_SCHEMA_SUPPLIER = Suppliers.memoize(() -> NetworkSerDe.createSchema(DefaultExtensionsSupplier.getInstance()));
private static final Supplier<ConcurrentMap<IidmVersion, Schema>> DEFAULT_SCHEMAS_SUPPLIER = Suppliers.memoize(ConcurrentHashMap::new);
Comment thread
samirromdhani marked this conversation as resolved.

private static final int MAX_NAMESPACE_PREFIX_NUM = 100;
private static final String XSD_RESOURCE_DIR = "/xsd/";
private static final Set<String> ALLOWED_IIDM_XSDS = Stream.of(IidmVersion.values())
.flatMap(v -> v.supportEquipmentValidationLevel()
? Stream.of(v.getXsd(), v.getXsd(false))
: Stream.of(v.getXsd()))
.collect(Collectors.toUnmodifiableSet());

private NetworkSerDe() {
}
Expand All @@ -94,6 +109,10 @@ public static void validate(InputStream is) {
validate(is, DefaultExtensionsSupplier.getInstance());
}

/**
 * Validates an XML stream against the schema of a single IIDM version,
 * using the default extensions supplier.
 *
 * <p>This is a shortcut for the three-argument {@code validate} overload, which reads the
 * whole stream into memory before validating it.</p>
 *
 * @param is      the XML document to validate
 * @param version the IIDM version the document is validated against
 */
public static void validate(InputStream is, IidmVersion version) {
    ExtensionsSupplier defaultExtensions = DefaultExtensionsSupplier.getInstance();
    validate(is, version, defaultExtensions);
}

public static void validate(InputStream is, ExtensionsSupplier extensionsSupplier) {
Objects.requireNonNull(extensionsSupplier);
Schema schema;
Expand Down Expand Up @@ -125,19 +144,17 @@ private static Schema createSchema(ExtensionsSupplier extensionsSupplier) {
for (ExtensionSerDe<?, ?> e : extensionsSupplier.get().getProviders()) {
e.getXsdAsStreamList().forEach(xsd -> additionalSchemas.add(new StreamSource(xsd)));
}
SchemaFactory factory = SchemaFactory.newInstance(XMLConstants.W3C_XML_SCHEMA_NS_URI);
SchemaFactory factory = newSchemaFactory();
try {
factory.setProperty(XMLConstants.ACCESS_EXTERNAL_SCHEMA, "");
factory.setProperty(XMLConstants.ACCESS_EXTERNAL_DTD, "");
int length = IidmVersion.values().length + (int) Arrays.stream(IidmVersion.values())
.filter(IidmVersion::supportEquipmentValidationLevel).count();
Source[] sources = new Source[additionalSchemas.size() + length];
int i = 0;
int j = 0;
for (IidmVersion version : IidmVersion.values()) {
sources[i] = new StreamSource(NetworkSerDe.class.getResourceAsStream("/xsd/" + version.getXsd()));
sources[i] = new StreamSource(NetworkSerDe.class.getResourceAsStream(XSD_RESOURCE_DIR + version.getXsd()));
if (version.supportEquipmentValidationLevel()) {
sources[j + IidmVersion.values().length] = new StreamSource(NetworkSerDe.class.getResourceAsStream("/xsd/" + version.getXsd(false)));
sources[j + IidmVersion.values().length] = new StreamSource(NetworkSerDe.class.getResourceAsStream(XSD_RESOURCE_DIR + version.getXsd(false)));
j++;
}
i++;
Expand All @@ -151,6 +168,198 @@ private static Schema createSchema(ExtensionsSupplier extensionsSupplier) {
}
}

public static void validate(InputStream is, IidmVersion version, ExtensionsSupplier extensionsSupplier) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add Javadoc for this method to explain its specificities.

Also, add a warning that it loads the full network file in memory.

Objects.requireNonNull(is);
Objects.requireNonNull(version);
Objects.requireNonNull(extensionsSupplier);

// check version namespace
byte[] xmlBytes;
try {
xmlBytes = is.readAllBytes();
checkNamespace(xmlBytes, version);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This call could be moved outside the try/catch, as it does not throw an IOException.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this check even necessary? There might be a better way to do this, avoiding loading the whole file in memory (which you are doing by creating the array of bytes).

For example something like this:

    public static void validate(InputStream is, IidmVersion version, ExtensionsSupplier extensionsSupplier) {
        Objects.requireNonNull(is);
        Objects.requireNonNull(version);
        Objects.requireNonNull(extensionsSupplier);

        // XSD validation
        Schema schema = extensionsSupplier == DefaultExtensionsSupplier.getInstance() ?
            DEFAULT_SCHEMAS_SUPPLIER.get().computeIfAbsent(version, v -> createSchema(DefaultExtensionsSupplier.getInstance(), v)) :
            createSchema(extensionsSupplier, version);
        try {
            XMLFilter xmlFilter = getXMLFilter(version);
            SAXSource saxSource = new SAXSource(xmlFilter, new InputSource(is));
            schema.newValidator().validate(saxSource);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (SAXException e) {
            throw new UncheckedSaxException(e);
        } catch (ParserConfigurationException e) {
            throw new UncheckedParserConfigurationException(e);
        }
    }

    private static XMLFilter getXMLFilter(IidmVersion validationVersion) throws ParserConfigurationException, SAXException {
        XMLFilter filter = new XMLFilterImpl() {
            private boolean rootSeen = false;

            @Override
            public void startElement(String uri, String localName, String qName, Attributes atts)
                throws SAXException {
                if (!rootSeen) {
                    checkNamespace(uri, validationVersion);
                    rootSeen = true;
                }
                super.startElement(uri, localName, qName, atts);
            }
        };

        XMLReader xmlReader = SAXParserFactory.newNSInstance().newSAXParser().getXMLReader();
        filter.setParent(xmlReader);
        return filter;
    }

    private static void checkNamespace(String actualNamespace, IidmVersion validationVersion) {
        boolean matches = actualNamespace.equals(validationVersion.getNamespaceURI())
            || validationVersion.supportEquipmentValidationLevel() && actualNamespace.equals(validationVersion.getNamespaceURI(false));
        if (!matches) {
            throw new PowsyblException("Namespace mismatch: expected validation version " + validationVersion.toString(".") + ", found namespace " + actualNamespace);
        }
    }

Note: this is an example, this has to be checked/improved/validated!

} catch (IOException e) {
throw new UncheckedIOException(e);
}
// XSD validation
Schema schema;
if (extensionsSupplier == DefaultExtensionsSupplier.getInstance()) {
schema = DEFAULT_SCHEMAS_SUPPLIER.get().computeIfAbsent(version, v -> createSchema(DefaultExtensionsSupplier.getInstance(), v));
} else {
schema = createSchema(extensionsSupplier, version);
}
Comment on lines +184 to +190
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// XSD validation
Schema schema;
if (extensionsSupplier == DefaultExtensionsSupplier.getInstance()) {
schema = DEFAULT_SCHEMAS_SUPPLIER.get().computeIfAbsent(version, v -> createSchema(DefaultExtensionsSupplier.getInstance(), v));
} else {
schema = createSchema(extensionsSupplier, version);
}
// XSD validation
Schema schema = extensionsSupplier == DefaultExtensionsSupplier.getInstance() ?
DEFAULT_SCHEMAS_SUPPLIER.get().computeIfAbsent(version, v -> createSchema(DefaultExtensionsSupplier.getInstance(), v)) :
createSchema(extensionsSupplier, version);

try {
schema.newValidator().validate(new StreamSource(new ByteArrayInputStream(xmlBytes)));
} catch (IOException e) {
throw new UncheckedIOException(e);
} catch (SAXException e) {
throw new UncheckedSaxException(e);
}
}

private static Schema createSchema(ExtensionsSupplier extensionsSupplier, IidmVersion version) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method could be factorized with the existing one.
You might want to keep working on arrays instead of lists for efficiency (to be tested if possible)

Objects.requireNonNull(extensionsSupplier);
Objects.requireNonNull(version);

SchemaFactory factory = newSchemaFactory();
try {
List<Source> sources = new ArrayList<>();
// iidm: source
sources.add(new StreamSource(NetworkSerDe.class.getResourceAsStream(XSD_RESOURCE_DIR + version.getXsd())));
// equipment: source
if (version.supportEquipmentValidationLevel()) {
sources.add(new StreamSource(NetworkSerDe.class.getResourceAsStream(XSD_RESOURCE_DIR + version.getXsd(false))));
}
// extension: sources
sources.addAll(getExtensionSources(extensionsSupplier, version));

return factory.newSchema(sources.toArray(Source[]::new));
} catch (SAXException e) {
throw new UncheckedSaxException(e);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}

/**
 * Builds the list of XSD sources required to validate extensions for a given IIDM version.
 *
 * <p>Some extension XSDs import an IIDM schema through {@code xs:import/@schemaLocation}.
 * For each extension supported by the provided IIDM version, this method reads the extension XSD,
 * extracts the imported schema locations, and adds the corresponding IIDM XSD resources
 * followed by the extension XSD itself.</p>
 *
 * <p>Only extensions supported by the provided IIDM version are considered.</p>
 *
 * @param extensionsSupplier extension provider used to discover available extensions
 * @param version            IIDM version used to filter compatible extensions
 * @return list of additional schema sources required by the extensions
 * @throws IOException if an extension XSD stream cannot be read or closed
 */
private static List<Source> getExtensionSources(ExtensionsSupplier extensionsSupplier, IidmVersion version) throws IOException {
    List<Source> sources = new ArrayList<>();
    for (ExtensionSerDe<?, ?> extension : getSupportedExtensionSerDeByIIdmVersion(extensionsSupplier, version)) {
        byte[] extensionXsd;
        // The XSD is buffered because it is consumed twice (import scan, then schema source);
        // once buffered, the underlying stream is released instead of being leaked.
        try (InputStream in = extension.getXsdAsStream()) {
            extensionXsd = in.readAllBytes();
        }
        // IIDM XSDs imported by the extension XSD must be registered before the extension XSD itself
        for (String schemaLocation : extractSchemaLocations(extensionXsd)) {
            sources.add(new StreamSource(NetworkSerDe.class.getResourceAsStream(XSD_RESOURCE_DIR + schemaLocation)));
        }
        // extension XSD
        sources.add(new StreamSource(new ByteArrayInputStream(extensionXsd)));
    }
    return sources;
}

/**
 * Selects the extension serializers usable with a given IIDM version.
 *
 * <p>A provider that is an {@code AbstractVersionableNetworkExtensionSerDe} is kept only when
 * {@code versionExists(version)} is true; every other provider is always kept.</p>
 *
 * @param extensionsSupplier supplier giving access to the extension providers
 * @param version            IIDM version checked against versionable providers
 * @return the retained providers, in iteration order
 */
private static List<ExtensionSerDe<?, ?>> getSupportedExtensionSerDeByIIdmVersion(ExtensionsSupplier extensionsSupplier, IidmVersion version) {
    List<ExtensionSerDe<?, ?>> supported = new ArrayList<>();
    for (ExtensionSerDe<?, ?> candidate : extensionsSupplier.get().getProviders()) {
        // Only a versionable extension can be excluded, and only when it lacks the requested version
        boolean excluded = candidate instanceof AbstractVersionableNetworkExtensionSerDe<?, ?, ?> versionable
                && !versionable.versionExists(version);
        if (excluded) {
            continue;
        }
        supported.add(candidate);
    }
    return supported;
}

private static void checkNamespace(byte[] xmlBytes, IidmVersion validationVersion) {
String actualNs = readRootNamespace(xmlBytes);
boolean matches = actualNs.equals(validationVersion.getNamespaceURI())
|| validationVersion.supportEquipmentValidationLevel() && actualNs.equals(validationVersion.getNamespaceURI(false));
if (!matches) {
throw new PowsyblException("Namespace mismatch: expected validation version " + validationVersion.toString(".") + ", found namespace " + actualNs);
Comment on lines +266 to +270
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
String actualNs = readRootNamespace(xmlBytes);
boolean matches = actualNs.equals(validationVersion.getNamespaceURI())
|| validationVersion.supportEquipmentValidationLevel() && actualNs.equals(validationVersion.getNamespaceURI(false));
if (!matches) {
throw new PowsyblException("Namespace mismatch: expected validation version " + validationVersion.toString(".") + ", found namespace " + actualNs);
String actualNamespace = readRootNamespace(xmlBytes);
boolean matches = actualNamespace.equals(validationVersion.getNamespaceURI())
|| validationVersion.supportEquipmentValidationLevel() && actualNamespace.equals(validationVersion.getNamespaceURI(false));
if (!matches) {
throw new PowsyblException("Namespace mismatch: expected validation version " + validationVersion.toString(".") + ", found namespace " + actualNamespace);

}
}

/**
* Extract {@code xs:import/@schemaLocation} from XSD document
*
* <p>XSD document snippet:</p>
* <pre>{@code
* ...
* targetNamespace="http://www.powsybl.org/schema/iidm/ext/extension-name/1_0"
* xmlns:iidm="http://www.powsybl.org/schema/iidm/1_10">
* <xs:import namespace="http://www.powsybl.org/schema/iidm/1_10" schemaLocation="iidm_V1_10.xsd"/>
* </xs:schema>
* }</pre>
Comment on lines +277 to +284
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea to illustrate with a small example!
Could you do the same with the other methods (when possible and useful)?

*
* @param xsdBytes XSD content as bytes
* @return schema locations found in {@code xs:import}
*/
private static List<String> extractSchemaLocations(byte[] xsdBytes) {
    // Unchecked facade: the checked XMLStreamException declared by the parsing helper
    // is rethrown as UncheckedXmlStreamException.
    List<String> schemaLocations;
    try {
        schemaLocations = proceedExtractSchemaLocations(xsdBytes);
    } catch (XMLStreamException streamFailure) {
        throw new UncheckedXmlStreamException(streamFailure);
    }
    return schemaLocations;
}

/**
 * Scans an XSD document and collects the {@code schemaLocation} attribute of every
 * {@code import} element in the XML Schema namespace, keeping only the values that are
 * non-blank and present in {@code ALLOWED_IIDM_XSDS}.
 *
 * @param xsdBytes XSD content as bytes
 * @return the retained schema locations, in document order
 * @throws PowsyblException   if the document cannot be parsed
 * @throws XMLStreamException if closing the stream reader fails
 */
private static List<String> proceedExtractSchemaLocations(byte[] xsdBytes) throws XMLStreamException {
    XMLStreamReader xsdReader = null;
    try (ByteArrayInputStream xsdStream = new ByteArrayInputStream(xsdBytes)) {
        xsdReader = getXMLInputFactory().createXMLStreamReader(xsdStream);
        List<String> allowedLocations = new ArrayList<>();
        while (xsdReader.hasNext()) {
            if (xsdReader.next() != XMLStreamConstants.START_ELEMENT) {
                continue;
            }
            boolean isSchemaImport = XMLConstants.W3C_XML_SCHEMA_NS_URI.equals(xsdReader.getNamespaceURI())
                    && "import".equals(xsdReader.getLocalName());
            if (!isSchemaImport) {
                continue;
            }
            String location = xsdReader.getAttributeValue(null, "schemaLocation");
            // The null check comes first so that the allow-list is never queried with null
            if (location == null || location.isBlank() || !ALLOWED_IIDM_XSDS.contains(location)) {
                continue;
            }
            allowedLocations.add(location);
        }
        return allowedLocations;
    } catch (XMLStreamException | IOException e) {
        throw new PowsyblException("Failed to parse XSD schema", e);
    } finally {
        // XMLStreamReader is not AutoCloseable, hence the explicit close
        if (xsdReader != null) {
            xsdReader.close();
        }
    }
}

/**
 * Returns the namespace URI of the root {@code <network>} element of an XML document.
 *
 * <p>The checked {@link XMLStreamException} declared by the parsing helper is rethrown
 * as {@code UncheckedXmlStreamException}.</p>
 *
 * @param xmlBytes XML document content as bytes
 * @return namespace URI of the root element
 */
private static String readRootNamespace(byte[] xmlBytes) {
    String rootNamespace;
    try {
        rootNamespace = proceedReadRootNamespace(xmlBytes);
    } catch (XMLStreamException streamFailure) {
        throw new UncheckedXmlStreamException(streamFailure);
    }
    return rootNamespace;
}

/**
 * Reads the document up to its first start element and returns that element's namespace URI.
 *
 * @param xmlBytes XML document content as bytes
 * @return the non-blank namespace URI of the first start element
 * @throws PowsyblException   if the first start element is not named {@code NETWORK_ROOT_ELEMENT_NAME},
 *                            if it has no (or a blank) namespace, if the document contains no
 *                            start element, or if the document cannot be parsed
 * @throws XMLStreamException if closing the stream reader fails
 */
private static String proceedReadRootNamespace(byte[] xmlBytes) throws XMLStreamException {
    XMLStreamReader xmlReader = null;
    try (ByteArrayInputStream xmlStream = new ByteArrayInputStream(xmlBytes)) {
        xmlReader = getXMLInputFactory().createXMLStreamReader(xmlStream);
        while (xmlReader.hasNext()) {
            if (xmlReader.next() != XMLStreamConstants.START_ELEMENT) {
                continue;
            }
            // The first start element is the document root: decide here and stop reading
            String rootName = xmlReader.getLocalName();
            if (!NETWORK_ROOT_ELEMENT_NAME.equals(rootName)) {
                throw new PowsyblException("Unexpected root element: " + rootName);
            }
            String namespaceUri = xmlReader.getNamespaceURI();
            if (namespaceUri != null && !namespaceUri.isBlank()) {
                return namespaceUri;
            }
            throw new PowsyblException("Missing root namespace");
        }
        throw new PowsyblException("Missing root namespace");
    } catch (XMLStreamException | IOException e) {
        throw new PowsyblException("Failed to read namespace from XML", e);
    } finally {
        // XMLStreamReader is not AutoCloseable, hence the explicit close
        if (xmlReader != null) {
            xmlReader.close();
        }
    }
}

private static void throwExceptionIfOption(AbstractOptions<?> options, String message) {
if (options.isThrowExceptionIfExtensionNotFound()) {
throw new PowsyblException(message);
Expand Down
Loading