using Ryujinx.Graphics.Shader.Decoders;
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using System.Collections.Generic;

using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper;
using static Ryujinx.Graphics.Shader.Translation.Translator;

namespace Ryujinx.Graphics.Shader.Translation
{
    /// <summary>
    /// Holds the decoded blocks and configuration of a shader located at a given address,
    /// and translates it into a <see cref="ShaderProgram"/>, optionally after combining it
    /// with the code of another <see cref="TranslatorContext"/>.
    /// </summary>
    public class TranslatorContext
    {
        private readonly Block[][] _cfg;
        private ShaderConfig _config;

        /// <summary>
        /// Address supplied when this context was created.
        /// </summary>
        public ulong Address { get; }

        /// <summary>
        /// Shader stage, as reported by the shader configuration.
        /// </summary>
        public ShaderStage Stage => _config.Stage;

        /// <summary>
        /// Shader size, as reported by the shader configuration.
        /// </summary>
        public int Size => _config.Size;

        /// <summary>
        /// Feature flags currently recorded on the shader configuration.
        /// </summary>
        public FeatureFlags UsedFeatures => _config.UsedFeatures;

        /// <summary>
        /// Texture handle set exposed by the shader configuration.
        /// This is the configuration's own set, not a copy.
        /// </summary>
        public HashSet<int> TextureHandlesForCache => _config.TextureHandlesForCache;

        /// <summary>
        /// GPU accessor exposed by the shader configuration.
        /// </summary>
        public IGpuAccessor GpuAccessor => _config.GpuAccessor;

        /// <summary>
        /// Creates a new translator context.
        /// </summary>
        /// <param name="address">Address of the shader</param>
        /// <param name="cfg">Decoded blocks of the shader, one array per function</param>
        /// <param name="config">Configuration used when emitting and translating the shader</param>
        internal TranslatorContext(ulong address, Block[][] cfg, ShaderConfig config)
        {
            Address = address;
            _config = config;
            _cfg    = cfg;
        }

        /// <summary>
        /// Checks whether an operand is an attribute whose value lies inside the user attribute range.
        /// </summary>
        /// <param name="operand">Operand to check, may be null</param>
        /// <returns>True if the operand is a user attribute, false otherwise (including for null)</returns>
        private static bool IsUserAttribute(Operand operand)
        {
            return operand != null &&
                   operand.Type == OperandType.Attribute &&
                   operand.Value >= AttributeConsts.UserAttributeBase &&
                   operand.Value < AttributeConsts.UserAttributeEnd;
        }

        /// <summary>
        /// Gets the index of the temporary variable slot that stands in for a user attribute operand.
        /// </summary>
        /// <remarks>
        /// The attribute value relative to the user attribute base is divided by 4, which matches
        /// the temporaries array being sized for four entries per user attribute.
        /// </remarks>
        /// <param name="operand">Operand for which <see cref="IsUserAttribute"/> returned true</param>
        /// <returns>Index into the temporaries array used by <see cref="Combine"/></returns>
        private static int GetUserAttributeTempIndex(Operand operand)
        {
            return (operand.Value - AttributeConsts.UserAttributeBase) / 4;
        }

        /// <summary>
        /// Combines the code of two shaders into a single set of functions, where the main
        /// function (index 0) runs the main function of <paramref name="a"/> followed by the
        /// main function of <paramref name="b"/>.
        /// </summary>
        /// <remarks>
        /// The operations of both inputs are modified in place (destinations of A and sources
        /// of B are rewritten), so the inputs should not be reused after this call.
        /// </remarks>
        /// <param name="a">Functions of the shader that runs first; index 0 is its main function</param>
        /// <param name="b">Functions of the shader that runs second; index 0 is its main function</param>
        /// <returns>
        /// The combined main function at index 0, followed by the remaining functions of
        /// <paramref name="a"/> and then the remaining functions of <paramref name="b"/>
        /// </returns>
        private static FunctionCode[] Combine(FunctionCode[] a, FunctionCode[] b)
        {
            // Here we combine two shaders.
            // For shader A:
            // - All user attribute stores on shader A are turned into copies to a
            // temporary variable. It's assumed that shader B will consume them.
            // - All return instructions are turned into branch instructions, the
            // branch target being the start of the shader B code.
            // For shader B:
            // - All user attribute loads on shader B are turned into copies from a
            // temporary variable, as long as that attribute is written by shader A.
            FunctionCode[] output = new FunctionCode[a.Length + b.Length - 1];

            // The combined main function receives exactly one operation per operation of
            // both main functions, plus the label that marks the start of shader B.
            List<Operation> ops = new List<Operation>(a[0].Code.Length + b[0].Code.Length + 1);

            Operand[] temps = new Operand[AttributeConsts.UserAttributesCount * 4];

            Operand lblB = Label();

            for (int index = 0; index < a[0].Code.Length; index++)
            {
                Operation operation = a[0].Code[index];

                if (IsUserAttribute(operation.Dest))
                {
                    int tIndex = GetUserAttributeTempIndex(operation.Dest);

                    Operand temp = temps[tIndex];

                    // Temporaries are created lazily, so that a null entry means
                    // "shader A never writes this attribute".
                    if (temp == null)
                    {
                        temp = Local();

                        temps[tIndex] = temp;
                    }

                    operation.Dest = temp;
                }

                if (operation.Inst == Instruction.Return)
                {
                    // Returning from A must continue into B rather than ending the shader.
                    ops.Add(new Operation(Instruction.Branch, lblB));
                }
                else
                {
                    ops.Add(operation);
                }
            }

            ops.Add(new Operation(Instruction.MarkLabel, lblB));

            for (int index = 0; index < b[0].Code.Length; index++)
            {
                Operation operation = b[0].Code[index];

                for (int srcIndex = 0; srcIndex < operation.SourcesCount; srcIndex++)
                {
                    Operand src = operation.GetSource(srcIndex);

                    if (IsUserAttribute(src))
                    {
                        Operand temp = temps[GetUserAttributeTempIndex(src)];

                        // Attributes that A did not write are left untouched.
                        if (temp != null)
                        {
                            operation.SetSource(srcIndex, temp);
                        }
                    }
                }

                ops.Add(operation);
            }

            output[0] = new FunctionCode(ops.ToArray());

            // Non-main functions are appended unchanged: first those of A, then those of B.
            for (int i = 1; i < a.Length; i++)
            {
                output[i] = a[i];
            }

            for (int i = 1; i < b.Length; i++)
            {
                output[a.Length + i - 1] = b[i];
            }

            return output;
        }

        /// <summary>
        /// Translates the shader held by this context into a shader program.
        /// </summary>
        /// <remarks>
        /// When <paramref name="other"/> is supplied, its used features and texture handles are
        /// merged into this context's configuration, and its code is placed before this
        /// context's code in the combined shader (it acts as "shader A" of the combination).
        /// </remarks>
        /// <param name="shaderProgramInfo">Information about the translated shader program</param>
        /// <param name="other">Optional context of a shader to run before this one</param>
        /// <returns>The translated shader program</returns>
        public ShaderProgram Translate(out ShaderProgramInfo shaderProgramInfo, TranslatorContext other = null)
        {
            FunctionCode[] code = EmitShader(_cfg, _config);

            if (other != null)
            {
                _config.SetUsedFeature(other._config.UsedFeatures);
                TextureHandlesForCache.UnionWith(other.TextureHandlesForCache);

                code = Combine(EmitShader(other._cfg, other._config), code);
            }

            return Translator.Translate(code, _config, out shaderProgramInfo);
        }
    }
}