package be.business.connector.core.utils;

import be.business.connector.core.exceptions.IntegrationModuleException;
import be.business.connector.core.exceptions.XMLGenerationException;
import be.business.connector.core.exceptions.XMLValidationException;
import be.business.connector.core.services.GenericWebserviceCaller;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.xml.sax.SAXException;

import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Marshaller;
import javax.xml.bind.Unmarshaller;
import javax.xml.transform.*;
import javax.xml.transform.stream.StreamResult;
import javax.xml.transform.stream.StreamSource;
import javax.xml.validation.Schema;
import javax.xml.validation.SchemaFactory;
import javax.xml.validation.Validator;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.security.Key;
import java.util.HashMap;
import java.util.concurrent.atomic.AtomicInteger;

import static javax.xml.XMLConstants.W3C_XML_SCHEMA_NS_URI;
import static org.slf4j.LoggerFactory.getLogger;

/**
 * Converts between JAXB-bound objects and their XML representation, with optional XSD
 * validation, payload logging and message dumping.
 *
 * <p>Validation and payload logging are switched on through {@link PropertyHandler} properties
 * whose value must be exactly {@code "enabled"}; every other value (including the default
 * {@code "disabled"}) turns the feature off.
 *
 * <p>NOTE(review): one {@link Marshaller} and one {@link Unmarshaller} are created per helper and
 * reused by every call; JAXB does not document these as thread-safe, so confirm that a helper
 * instance is not shared between threads.
 *
 * @param <X> type returned by the unmarshalling methods
 * @param <Y> type accepted by the marshalling methods
 */
public class MarshallerHelper<X, Y> {
    private static final String ENABLED = "enabled";
    private static final String DISABLED = "disabled";

    private static final String SCHEMA_VALIDATION_PROPERTY = "connector.xml.schema.validation";
    private static final String RESPONSE_VALIDATION_PROPERTY =
            "connector.xml.schema.response.validation";
    private static final String PAYLOAD_LOGGING_PROPERTY =
            "connector.xml.schema.validation.payload.logging";

    private static final String UNSEALED_XSD_ROOT = "/META-INF/recipe-central-system-unsealed/XSD/";
    private static final String PRESCRIBER_XSD =
            UNSEALED_XSD_ROOT + "UnsealedPrescriberServicesV4_schema1.xsd";
    private static final String PATIENT_XSD =
            UNSEALED_XSD_ROOT + "UnsealedPatientServicesV4_schema1.xsd";
    private static final String EXECUTOR_XSD =
            UNSEALED_XSD_ROOT + "UnsealedExecutorServicesV4_schema1.xsd";

    /**
     * Number of {@link #withoutXSDValidation(Runnable)} tasks currently running. Validation is
     * allowed only while this is zero. A counter (rather than a boolean that is reset to
     * {@code true}) keeps validation suppressed until the outermost / last task has finished, so
     * nested or overlapping calls cannot re-enable it early.
     */
    private static final AtomicInteger ACTIVE_VALIDATION_SUPPRESSIONS = new AtomicInteger();

    private static final Logger logger = getLogger(MarshallerHelper.class);

    private final Unmarshaller unmarshaller;
    private final Marshaller marshaller;

    /**
     * Creates the JAXB unmarshaller and marshaller for the given classes. The marshaller is
     * configured to produce formatted output.
     *
     * @param unmarshallClass class used to build the JAXB context for unmarshalling
     * @param marshallClass class used to build the JAXB context for marshalling
     * @throws IllegalArgumentException if a JAXB context, marshaller or unmarshaller cannot be
     *     created; the {@link JAXBException} is attached as the cause
     */
    public MarshallerHelper(Class<X> unmarshallClass, Class<Y> marshallClass) {
        try {
            unmarshaller = JAXBContext.newInstance(unmarshallClass).createUnmarshaller();
            marshaller = JAXBContext.newInstance(marshallClass).createMarshaller();
            marshaller.setProperty(Marshaller.JAXB_FORMATTED_OUTPUT, Boolean.TRUE);
        } catch (JAXBException e) {
            throw new IllegalArgumentException("JAXBException " + e, e);
        }
    }

    /**
     * Runs {@code task} with XSD validation switched off, and restores it afterwards even if the
     * task throws.
     *
     * <p>The switch is static, so validation is suppressed for every helper instance while the
     * task runs, not only for work done by the task itself.
     *
     * @param task work to execute while validation is suppressed
     */
    public static void withoutXSDValidation(Runnable task) {
        ACTIVE_VALIDATION_SUPPRESSIONS.incrementAndGet();
        try {
            task.run();
        } finally {
            ACTIVE_VALIDATION_SUPPRESSIONS.decrementAndGet();
        }
    }

    /**
     * Marshals {@code data} to XML, dumps it through {@link MessageDumper}, validates it when
     * schema validation is enabled, and optionally logs the payload at debug level.
     *
     * @param data object to marshal
     * @return the marshalled XML bytes
     * @throws IllegalArgumentException if marshalling fails or the buffer cannot be closed; the
     *     original exception is attached as the cause
     * @throws XMLValidationException if validation is enabled and the XML does not validate
     */
    public byte[] toXMLByteArray(Y data) {
        try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) {
            marshaller.marshal(data, bos);
            MessageDumper.getInstance().dump(bos, data.getClass().getSimpleName(), MessageDumper.OUT);
            // toByteArray() copies the buffer on every call, so take a single snapshot.
            final byte[] xmlBytes = bos.toByteArray();
            validate(xmlBytes);
            logPayload(xmlBytes);
            return xmlBytes;
        } catch (JAXBException e) {
            throw new IllegalArgumentException("JAXBException " + e, e);
        } catch (IOException e) {
            throw new IllegalArgumentException("IOException " + e, e);
        }
    }

    /**
     * Validates {@code xmlBytes} against the unsealed prescriber, patient and executor schemas,
     * provided schema validation is enabled by property and not currently suppressed through
     * {@link #withoutXSDValidation(Runnable)}. Does nothing otherwise.
     *
     * @param xmlBytes XML document to validate
     * @throws XMLValidationException if a schema cannot be loaded or the document is invalid
     */
    private void validate(byte[] xmlBytes) {
        if (!(schemaValidationEnabled() && xsdValidationAllowed())) {
            return;
        }
        logger.debug("Validating XML against Schema...");
        SchemaFactory schemaFactory = SchemaFactory.newInstance(W3C_XML_SCHEMA_NS_URI);

        // Maps imported namespaces to the classpath location of their schema.
        final HashMap<String, String> schemaMapping = new HashMap<>();
        schemaMapping.put(
                "http://www.w3.org/2000/09/xmldsig#", "/META-INF/external/XSD/xmldsig-core-schema.xsd");
        // NOTE(review): the single slash after "http:" is kept exactly as it was; confirm it
        // matches the namespace declared by recipe-core.xsd before "correcting" it.
        schemaMapping.put("http:/services.recipe.be/core", UNSEALED_XSD_ROOT + "recipe-core.xsd");
        schemaFactory.setResourceResolver(new ClasspathLSResourceResolver(schemaMapping));

        try (InputStream prescriberSchema = openSchema(PRESCRIBER_XSD);
             InputStream patientSchema = openSchema(PATIENT_XSD);
             InputStream executorSchema = openSchema(EXECUTOR_XSD)) {
            Schema schema =
                    schemaFactory.newSchema(
                            new Source[]{
                                    new StreamSource(prescriberSchema),
                                    new StreamSource(patientSchema),
                                    new StreamSource(executorSchema)
                            });
            Validator validator = schema.newValidator();
            validator.validate(new StreamSource(new ByteArrayInputStream(xmlBytes)));
        } catch (IOException | SAXException e) {
            logger.warn("XML Schema validation failed!", e);
            // A document that fails validation may not even be well-formed, in which case pretty
            // printing would throw and hide the real cause; fall back to the raw payload instead.
            throw new XMLValidationException(
                    "Failed to validate XML!" + System.lineSeparator() + prettyPrintOrRaw(xmlBytes), e);
        }
    }

    /**
     * Opens a schema from the classpath, relative to {@link GenericWebserviceCaller}.
     *
     * @param resourcePath absolute classpath location of the schema
     * @return an open stream on the schema; the caller must close it
     * @throws IOException if the resource does not exist
     */
    private static InputStream openSchema(String resourcePath) throws IOException {
        InputStream stream = GenericWebserviceCaller.class.getResourceAsStream(resourcePath);
        if (stream == null) {
            throw new FileNotFoundException("Schema resource not found on classpath: " + resourcePath);
        }
        return stream;
    }

    /** Returns whether no {@link #withoutXSDValidation(Runnable)} task is currently running. */
    private static boolean xsdValidationAllowed() {
        return ACTIVE_VALIDATION_SUPPRESSIONS.get() == 0;
    }

    /**
     * Re-serialises {@code xmlBytes} with indentation.
     *
     * @param xmlBytes XML document to format
     * @return the indented document
     * @throws XMLValidationException if the document cannot be transformed
     */
    private String prettyPrint(byte[] xmlBytes) {
        try {
            Transformer transformer = TransformerFactory.newInstance().newTransformer();
            transformer.setOutputProperty(OutputKeys.INDENT, "yes");
            StringWriter writer = new StringWriter();
            transformer.transform(
                    new StreamSource(new ByteArrayInputStream(xmlBytes)), new StreamResult(writer));
            return writer.toString();
        } catch (TransformerException e) {
            throw new XMLValidationException("Failed to pretty print XML!", e);
        }
    }

    /**
     * Like {@link #prettyPrint(byte[])}, but never fails because the document cannot be
     * transformed: in that case the bytes are returned undecorated. Used for diagnostics, where a
     * formatting problem must not replace the error being reported.
     *
     * @param xmlBytes XML document to format
     * @return the indented document, or the raw text when it cannot be formatted
     */
    private String prettyPrintOrRaw(byte[] xmlBytes) {
        try {
            return prettyPrint(xmlBytes);
        } catch (XMLValidationException e) {
            logger.debug("Could not pretty print XML, using raw payload instead", e);
            // Assumes the payload is UTF-8 encoded; used for diagnostics only.
            return new String(xmlBytes, StandardCharsets.UTF_8);
        }
    }

    /** Logs the payload at debug level when payload logging is enabled by property. */
    private void logPayload(byte[] xmlBytes) {
        if (payloadLoggingEnabled() && logger.isDebugEnabled()) {
            logger.debug(prettyPrintOrRaw(xmlBytes));
        }
    }

    /** Reads a property (default {@code "disabled"}) from the {@link PropertyHandler}. */
    private static String readStatus(String propertyName) {
        return PropertyHandler.getInstance().getProperty(propertyName, DISABLED);
    }

    /** Returns whether the schema validation property is set to {@code "enabled"}. */
    private boolean schemaValidationEnabled() {
        String status = readStatus(SCHEMA_VALIDATION_PROPERTY);
        logger.debug("XML schema validation {}!", status);
        return ENABLED.equals(status);
    }

    /** Returns whether the payload logging property is set to {@code "enabled"}. */
    private boolean payloadLoggingEnabled() {
        return ENABLED.equals(readStatus(PAYLOAD_LOGGING_PROPERTY));
    }

    /** Returns whether the response schema validation property is set to {@code "enabled"}. */
    private boolean responseSchemaValidationEnabled() {
        String status = readStatus(RESPONSE_VALIDATION_PROPERTY);
        logger.debug("XML response schema validation {}!", status);
        return ENABLED.equals(status);
    }

    /**
     * Unmarshals {@code data}, validating it first when response schema validation is enabled,
     * optionally logging the payload, and dumping it through {@link MessageDumper}.
     *
     * @param data XML document to unmarshal
     * @return the unmarshalled object
     * @throws XMLValidationException if validation is enabled and the XML does not validate
     * @throws XMLGenerationException if unmarshalling fails
     */
    @SuppressWarnings("unchecked")
    public X toObject(byte[] data) {
        if (responseSchemaValidationEnabled()) {
            validate(data);
        }
        logPayload(data);
        try (ByteArrayInputStream bis = new ByteArrayInputStream(data)) {
            X result = (X) unmarshaller.unmarshal(bis);
            MessageDumper.getInstance().dump(data, result.getClass().getSimpleName(), MessageDumper.IN);
            return result;
        } catch (JAXBException | IOException e) {
            throw new XMLGenerationException(e);
        }
    }

    /**
     * Reads {@code inputStream} fully and unmarshals it with {@link #toObject(byte[])}.
     *
     * @param inputStream stream containing the XML document
     * @return the unmarshalled object
     */
    public X toObject(InputStream inputStream) {
        return toObject(IOUtils.getBytes(inputStream));
    }

    /**
     * Marshals {@code data} to an XML string, without validation, logging or dumping.
     *
     * @param data object to marshal
     * @return the XML document
     * @throws JAXBException if marshalling fails
     */
    public String marsh(Y data) throws JAXBException {
        StringWriter xml = new StringWriter();
        marshaller.marshal(data, xml);
        return xml.toString();
    }

    /**
     * Unmarshals an XML string, without validation, logging or dumping.
     *
     * @param data XML document
     * @return the unmarshalled object
     * @throws JAXBException if unmarshalling fails
     */
    @SuppressWarnings("unchecked")
    public X unmarsh(String data) throws JAXBException {
        return (X) unmarshaller.unmarshal(new StringReader(data));
    }

    /**
     * Unmarshals {@code data}, reporting failures with the
     * {@code error.single.message.validation} label.
     *
     * @param data XML document
     * @return the unmarshalled object
     * @throws IntegrationModuleException if unmarshalling fails
     */
    public X unmarsh(byte[] data) throws IntegrationModuleException {
        return unmarshalOrFail(data, "error.single.message.validation");
    }

    /**
     * Unmarshals {@code data}, reporting failures with the
     * {@code error.systemconfiguration.validation} label.
     *
     * @param data XML document
     * @return the unmarshalled object
     * @throws IntegrationModuleException if unmarshalling fails
     */
    public X unmarshSystemConfiguration(byte[] data) throws IntegrationModuleException {
        return unmarshalOrFail(data, "error.systemconfiguration.validation");
    }

    /**
     * Unmarshals {@code data}, reporting failures with the
     * {@code error.productFilter.validation} label.
     *
     * @param data XML document
     * @return the unmarshalled object
     * @throws IntegrationModuleException if unmarshalling fails
     */
    public X unmarshProductFilter(byte[] data) throws IntegrationModuleException {
        return unmarshalOrFail(data, "error.productFilter.validation");
    }

    /**
     * Shared implementation of the {@code unmarsh*} byte-array variants, which differ only in the
     * label used for the error message.
     *
     * @param data XML document
     * @param errorLabelKey {@link I18nHelper} key of the message used when unmarshalling fails
     * @return the unmarshalled object
     * @throws IntegrationModuleException if unmarshalling fails
     */
    @SuppressWarnings("unchecked")
    private X unmarshalOrFail(byte[] data, String errorLabelKey) throws IntegrationModuleException {
        try {
            return (X) unmarshaller.unmarshal(new ByteArrayInputStream(data));
        } catch (JAXBException e) {
            throw new IntegrationModuleException(I18nHelper.getLabel(errorLabelKey), e);
        }
    }

    /**
     * Unseals {@code data} with {@code symmKey} and unmarshals the result with
     * {@link #toObject(byte[])}.
     *
     * @param data sealed content
     * @param symmKey key passed to {@link EncryptionUtils#unsealWithSymmKey}
     * @return the unmarshalled object
     */
    public X unsealWithSymmKey(byte[] data, Key symmKey) {
        return toObject(EncryptionUtils.unsealWithSymmKey(symmKey, data));
    }

    /**
     * Unseals {@code data} with {@code symmKey} without unmarshalling it.
     *
     * @param data sealed content
     * @param symmKey key passed to {@link EncryptionUtils#unsealWithSymmKey}
     * @return the unsealed bytes
     */
    public byte[] unsealWithKey(byte[] data, Key symmKey) {
        return EncryptionUtils.unsealWithSymmKey(symmKey, data);
    }

    /**
     * Writes the given bytes to the file at {@code archivingPath}.
     *
     * @param unsealByteWithSymmKeyDecodeAndDecompress content to write
     * @param archivingPath path of the target file; must not be blank
     * @throws IOException if the file cannot be written
     * @throws IntegrationModuleException if {@code archivingPath} is blank
     */
    public void writePrescriptionToFile(
            byte[] unsealByteWithSymmKeyDecodeAndDecompress, String archivingPath)
            throws IOException, IntegrationModuleException {
        if (StringUtils.isBlank(archivingPath)) {
            throw new IntegrationModuleException(I18nHelper.getLabel("error.archiving.path.missing"));
        }
        FileUtils.writeByteArrayToFile(
                new File(archivingPath), unsealByteWithSymmKeyDecodeAndDecompress);
    }
}
