#!/usr/bin/env python3
#
# Copyright (c) 2023 Red Hat, Inc.
#
# SPDX-License-Identifier: MIT
"""Unit tests for xml-preprocess"""
import contextlib
import importlib
import os
import platform
import subprocess
import tempfile
import unittest
from io import StringIO
from unittest import mock
xmlpp = importlib.import_module("xml-preprocess")
class TestXmlPreprocess(unittest.TestCase):
    """Tests for xml-preprocess.Preprocessor and its file helpers"""

    def _make_temp_file(self, content):
        """Create a closed named temporary file holding *content*.

        Removal is registered with ``addCleanup`` as soon as the file exists,
        so the file is deleted even when an assertion fails part-way through
        the test (a trailing ``os.remove`` would be skipped in that case).

        :param content: text written to the file.
        :returns: the path of the temporary file.
        """
        with tempfile.NamedTemporaryFile(
            mode="w", encoding="utf-8", delete=False
        ) as temp_file:
            self.addCleanup(self._remove_if_exists, temp_file.name)
            temp_file.write(content)
        return temp_file.name

    @staticmethod
    def _remove_if_exists(path):
        """Delete *path*, tolerating a test that already removed it."""
        with contextlib.suppress(FileNotFoundError):
            os.remove(path)

    def test_preprocess_xml(self):
        """preprocess_xml on an empty file yields an empty string."""
        temp_file_name = self._make_temp_file("")
        result = xmlpp.preprocess_xml(temp_file_name)
        self.assertEqual(result, "")

    def test_save_xml(self):
        """save_xml creates the target file."""
        # Use a path that does not exist yet, so the isfile() check proves
        # that save_xml itself created the file.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_name = os.path.join(temp_dir, "saved.xml")
            xmlpp.save_xml("", temp_file_name)
            self.assertTrue(os.path.isfile(temp_file_name))

    def test_include(self):
        """Included file content is inlined; a missing include raises."""
        inc_file_name = self._make_temp_file("Content from included file")
        xml_str = f""
        expected = "Content from included file"
        xpp = xmlpp.Preprocessor()
        result = xpp.preprocess(xml_str)
        self.assertEqual(result, expected)
        # Once the included file is gone, preprocessing must fail loudly.
        os.remove(inc_file_name)
        self.assertRaises(FileNotFoundError, xpp.preprocess, xml_str)

    def test_envvar(self):
        """$(env.NAME) expands to the environment value; unknown names raise."""
        # patch.dict restores os.environ on exit, so no state leaks into
        # other tests.
        with mock.patch.dict(os.environ, {"TEST_ENV_VAR": "TestValue"}):
            xml_str = "$(env.TEST_ENV_VAR)"
            expected = "TestValue"
            xpp = xmlpp.Preprocessor()
            result = xpp.preprocess(xml_str)
            self.assertEqual(result, expected)
            self.assertRaises(KeyError, xpp.preprocess, "$(env.UNKNOWN)")

    def test_sys_var(self):
        """$(sys.ARCH) matches platform.architecture(); unknown names raise."""
        xml_str = "$(sys.ARCH)"
        expected = f"{platform.architecture()[0]}"
        xpp = xmlpp.Preprocessor()
        result = xpp.preprocess(xml_str)
        self.assertEqual(result, expected)
        self.assertRaises(KeyError, xpp.preprocess, "$(sys.UNKNOWN)")

    def test_cus_var(self):
        """Exercise $(var.NAME) custom-variable expansion."""
        xml_str = "$(var.USER)"
        expected = ""
        xpp = xmlpp.Preprocessor()
        result = xpp.preprocess(xml_str)
        self.assertEqual(result, expected)
        xml_str = "$(var.USER)"
        expected = "FOO"
        xpp = xmlpp.Preprocessor()
        result = xpp.preprocess(xml_str)
        self.assertEqual(result, expected)

    def test_error_warning(self):
        """Warnings are printed to stdout; errors raise RuntimeError."""
        xml_str = ""
        expected = ""
        xpp = xmlpp.Preprocessor()
        out = StringIO()
        with contextlib.redirect_stdout(out):
            result = xpp.preprocess(xml_str)
        self.assertEqual(result, expected)
        self.assertEqual(out.getvalue(), "[Warning]: test warn\n")
        self.assertRaises(RuntimeError, xpp.preprocess, "")

    def test_cmd(self):
        """Command output is inlined; a failing command raises."""
        xpp = xmlpp.Preprocessor()
        result = xpp.preprocess('')
        self.assertEqual(result, "hello world")
        self.assertRaises(
            subprocess.CalledProcessError,
            xpp.preprocess, ''
        )

    def test_foreach(self):
        """Exercise foreach expansion of $(var.x)."""
        xpp = xmlpp.Preprocessor()
        result = xpp.preprocess(
            '$(var.x)'
        )
        self.assertEqual(result, "abc")

    def test_if_elseif(self):
        """Exercise if / elseif / else conditional blocks."""
        xpp = xmlpp.Preprocessor()
        result = xpp.preprocess('ok')
        self.assertEqual(result, "ok")
        result = xpp.preprocess('ok')
        self.assertEqual(result, "")
        result = xpp.preprocess('okko')
        self.assertEqual(result, "ok")
        result = xpp.preprocess('okko')
        self.assertEqual(result, "ko")
        result = xpp.preprocess(
            'okok2ko'
        )
        self.assertEqual(result, "ok2")
        result = xpp.preprocess(
            'okokko'
        )
        self.assertEqual(result, "ko")

    def test_ifdef(self):
        """Exercise ifdef conditional blocks."""
        xpp = xmlpp.Preprocessor()
        result = xpp.preprocess('okko')
        self.assertEqual(result, "ko")
        result = xpp.preprocess(
            'okko'
        )
        self.assertEqual(result, "ok")
if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script;
    # inside this branch __name__ is "__main__", unittest.main's default.
    unittest.main(module=__name__)