// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/Cloning.h"

#include "Passes.h"

using namespace llvm;

#define DEBUG_TYPE "undosret"

namespace {
struct UndoSRetPass : public ModulePass {
  static char ID;
  UndoSRetPass() : ModulePass(ID) {}

  bool runOnModule(Module &M) override;
};
} // namespace

char UndoSRetPass::ID = 0;
INITIALIZE_PASS(UndoSRetPass, "UndoSRet", "Undo SRet Pass", false, false)

namespace clspv {
llvm::ModulePass *createUndoSRetPass() { return new UndoSRetPass(); }
} // namespace clspv

bool UndoSRetPass::runOnModule(Module &M) {
  bool Changed = false;
  LLVMContext &Context = M.getContext();

  SmallVector<Function *, 8> WorkList;
  for (Function &F : M) {
    if (F.isDeclaration()) {
      continue;
    }

    if (F.getReturnType()->isVoidTy()) {
      for (Argument &Arg : F.args()) {
        // Check sret attribute.
        if (Arg.hasStructRetAttr()) {
          // We found a function that needs to be modified!
          WorkList.push_back(&F);
          Changed = true;
        }
      }
    }
  }

  for (Function *F : WorkList) {
    auto InsertPoint = F->getEntryBlock().getFirstNonPHIOrDbg();

    for (Argument &Arg : F->args()) {
      // Check sret attribute.
      if (Arg.hasStructRetAttr()) {
        PointerType *PTy = cast<PointerType>(Arg.getType());
        Type *RetTy = PTy->getElementType();
        // Create alloca instruction for return value on function's entry
        // block.
        AllocaInst *RetVal =
            new AllocaInst(RetTy, 0, nullptr, "retval", InsertPoint);

        // Change arg's users with retval.
        Arg.replaceAllUsesWith(RetVal);

        // Create new function type with real return type instead of sret
        // argument.
        SmallVector<Type *, 8> NewFuncParamTys;
        for (const auto &Arg : F->args()) {
          // Ignore argument with sret attribute.
          if (Arg.hasStructRetAttr()) {
            continue;
          }
          NewFuncParamTys.push_back(Arg.getType());
        }
        FunctionType *NewFuncTy =
            FunctionType::get(RetTy, NewFuncParamTys, false);

        // Create new function.
        Function *NewFunc = Function::Create(NewFuncTy, F->getLinkage());
        NewFunc->takeName(F);

        // Insert the function just after the original to preserve the ordering
        // of the functions within the module.
        auto &FunctionList = M.getFunctionList();

        for (auto Iter = FunctionList.begin(), IterEnd = FunctionList.end();
             Iter != IterEnd; ++Iter) {
          // If we find our functions place in the iterator.
          if (&*Iter == F) {
            FunctionList.insertAfter(Iter, NewFunc);
            break;
          }
        }

        // Map original function's arguments to new function's arguments.
        ValueToValueMapTy VMap;
        auto NewArg = NewFunc->arg_begin();
        for (auto &Arg : F->args()) {
          if (Arg.hasStructRetAttr()) {
            VMap[&Arg] = UndefValue::get(Arg.getType());
            continue;
          }
          NewArg->setName(Arg.getName());
          VMap[&Arg] = &*(NewArg++);
        }

        // Clone original function into new function.
        SmallVector<ReturnInst *, 4> RetInsts;
        CloneFunctionInto(NewFunc, F, VMap,
                          CloneFunctionChangeType::LocalChangesOnly, RetInsts);

        // Change return instruction like this.
        //
        // %retv = load %retval;
        // ret %retv;
        for (auto Ret : RetInsts) {
          LoadInst *LD = new LoadInst(RetTy, VMap[RetVal], "", Ret);
          ReturnInst *NewRet = ReturnInst::Create(Context, LD, Ret);
          Ret->replaceAllUsesWith(NewRet);
          Ret->eraseFromParent();
        }

        SmallVector<User *, 8> ToRemoves;

        // Update caller site.
        for (auto User : F->users()) {
          if (CallInst *Call = dyn_cast<CallInst>(User)) {
            // Create new call instruction for new function without sret.
            SmallVector<Value *, 8> NewArgs(Call->arg_begin() + 1,
                                            Call->arg_end());
            CallInst *NewCall = CallInst::Create(NewFunc, NewArgs, "", Call);

            NewCall->takeName(Call);
            NewCall->setCallingConv(Call->getCallingConv());
            NewCall->setDebugLoc(Call->getDebugLoc());

            // Copy attributes over, but skip the attributes for the first
            // parameter since it is removed.  In particular, the old
            // first parameter has a StructRet attribute that should disappear.
            auto attrs(Call->getAttributes());
            AttributeList new_attrs(
                AttributeList::get(Context, AttributeList::FunctionIndex,
                                   AttrBuilder(attrs.getFnAttributes())));
            new_attrs =
                new_attrs.addAttributes(Context, AttributeList::ReturnIndex,
                                        AttrBuilder(attrs.getRetAttributes()));
            for (unsigned i = 1; i < Call->getNumArgOperands(); i++) {
              new_attrs = new_attrs.addParamAttributes(
                  Context, i - 1, AttrBuilder(attrs.getParamAttributes(i)));
            }
            NewCall->setAttributes(new_attrs);

            // Store the value we returned from our function call into the
            // the orignal destination.
            new StoreInst(NewCall, Call->getArgOperand(0), Call);
          }

          ToRemoves.push_back(User);
        }

        for (User *U : ToRemoves) {
          U->dropAllReferences();
          if (Instruction *I = dyn_cast<Instruction>(U)) {
            I->eraseFromParent();
          }
        }

        // We found the argument that had sret, so we are done with this
        // function!
        break;
      }
    }

    // Delete original functions with sret argument.
    F->dropAllReferences();
    F->eraseFromParent();
  }

  return Changed;
}
